diff options
Diffstat (limited to 'pkg/tcpip/network')
25 files changed, 1562 insertions, 600 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 6905b9ccb..a72eb1aad 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -47,7 +47,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index e867b3c3f..0df39ae81 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 90075a70c..56b76a284 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragData := r.holes[i].pkt.Data() - resPkt.Data().ReadFromData(fragData, fragData.Size()) + stack.MergeFragment(resPkt, r.holes[i].pkt) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index d21b4c7ef..fd944ce99 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -6,6 +6,7 @@ go_library( name = "ip", srcs = [ "duplicate_address_detection.go", + "errors.go", "generic_multicast_protocol.go", "stats.go", ], diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index eed49f5d2..5123b7d6a 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) } + configs.Validate() + *d = DAD{ opts: opts, configs: configs, diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go new file mode 100644 index 000000000..94f1cd1cb --- /dev/null +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -0,0 +1,85 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// ForwardingError represents an error that occured while trying to forward +// a packet. +type ForwardingError interface { + isForwardingError() + fmt.Stringer +} + +// ErrTTLExceeded indicates that the received packet's TTL has been exceeded. +type ErrTTLExceeded struct{} + +func (*ErrTTLExceeded) isForwardingError() {} + +func (*ErrTTLExceeded) String() string { return "ttl exceeded" } + +// ErrParameterProblem indicates the received packet had a problem with an IP +// parameter. +type ErrParameterProblem struct{} + +func (*ErrParameterProblem) isForwardingError() {} + +func (*ErrParameterProblem) String() string { return "parameter problem" } + +// ErrLinkLocalSourceAddress indicates the received packet had a link-local +// source address. +type ErrLinkLocalSourceAddress struct{} + +func (*ErrLinkLocalSourceAddress) isForwardingError() {} + +func (*ErrLinkLocalSourceAddress) String() string { return "link local destination address" } + +// ErrLinkLocalDestinationAddress indicates the received packet had a link-local +// destination address. +type ErrLinkLocalDestinationAddress struct{} + +func (*ErrLinkLocalDestinationAddress) isForwardingError() {} + +func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" } + +// ErrNoRoute indicates that a route for the received packet couldn't be found. +type ErrNoRoute struct{} + +func (*ErrNoRoute) isForwardingError() {} + +func (*ErrNoRoute) String() string { return "no route" } + +// ErrMessageTooLong indicates the packet was too big for the outgoing MTU. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isForwardingError() {} + +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrOther indicates the packet coould not be forwarded for a reason +// captured by the contained error. +type ErrOther struct { + Err tcpip.Error +} + +func (*ErrOther) isForwardingError() {} + +func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index ac35d81e7..d22974b12 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ip holds IPv4/IPv6 common utilities. package ip import ( diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index d06b26309..0c2b62127 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -16,80 +16,145 @@ package ip import "gvisor.dev/gvisor/pkg/tcpip" +// LINT.IfChange(MultiCounterIPForwardingStats) + +// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter +// may have several versions. +type MultiCounterIPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable tcpip.MultiCounterStat + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL tcpip.MultiCounterStat + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource tcpip.MultiCounterStat + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination tcpip.MultiCounterStat + + // PacketTooBig is the number of IP packets which were dropped because they + // were too big for the outgoing MTU. + PacketTooBig tcpip.MultiCounterStat + + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem tcpip.MultiCounterStat + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) + m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) +} + +// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) + // LINT.IfChange(MultiCounterIPStats) // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the number of IP packets received from the link layer. + // PacketsReceived is the number of IP packets received from the link + // layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the number of IP packets received from the link - // layer when the IP layer is disabled. + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the number of IP packets received from + // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the number of IP packets received - // with an unknown or invalid destination address. + // InvalidDestinationAddressesReceived is the number of IP packets + // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the number of IP packets received with a - // source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received + // with a source address that should never have been received on the + // wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the number of incoming IP packets that are successfully + // PacketsDelivered is the number of incoming IP packets successfully // delivered to the transport layer. PacketsDelivered tcpip.MultiCounterStat // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the number of IP packets which failed to write to a - // link-layer endpoint. + // OutgoingPacketErrors is the number of IP packets which failed to + // write to a link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the number of IP Packets that were dropped due - // to the IP packet header failing validation checks. + // MalformedPacketsReceived is the number of IP Packets that were + // dropped due to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the number of IP Fragments that were dropped - // due to the fragment failing validation checks. + // MalformedFragmentsReceived is the number of IP Fragments that were + // dropped due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the number of IP packets dropped in the Input - // chain. + // IPTablesInputDropped is the number of IP packets dropped in the + // Input chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the number of IP packets dropped in the Output - // chain. + // IPTablesForwardDropped is the number of IP packets dropped in the + // Forward chain. + IPTablesForwardDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the number of IP packets dropped in the + // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat - // IPTablesPostroutingDropped is the number of IP packets dropped in the - // Postrouting chain. + // IPTablesPostroutingDropped is the number of IP packets dropped in + // the Postrouting chain. IPTablesPostroutingDropped tcpip.MultiCounterStat - // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option + // stats out of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat - // OptionRecordRouteReceived is the number of Record Route options seen. + // OptionRecordRouteReceived is the number of Record Route options + // seen. OptionRecordRouteReceived tcpip.MultiCounterStat - // OptionRouterAlertReceived is the number of Router Alert options seen. + // OptionRouterAlertReceived is the number of Router Alert options + // seen. OptionRouterAlertReceived tcpip.MultiCounterStat // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived tcpip.MultiCounterStat + + // Forwarding collects stats related to IP forwarding. + Forwarding MultiCounterIPForwardingStats } // Init sets internal counters to track a and b counters. func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived) m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) @@ -100,12 +165,14 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) + m.Forwarding.Init(&a.Forwarding, &b.Forwarding) } // LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index 1c4f583c7..cec3e62c4 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -4,10 +4,7 @@ package(licenses = ["notice"]) go_library( name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], + srcs = ["testutil.go"], visibility = [ "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/internal/fragmentation:__pkg__", diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index e2cf24b67..605e9ef8d 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -19,8 +19,6 @@ package testutil import ( "fmt" "math/rand" - "reflect" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go deleted file mode 100644 index 5ff764800..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "reflect" - "unsafe" -) - -// unsafeExposeUnexportedFields takes a Value and returns a version of it in -// which even unexported fields can be read and written. -func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { - return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7ee0495d9..c90974693 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -62,7 +62,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index f663fdc0b..d1a82b584 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet return } - // Skip the ip header, then deliver the error. - pkt.Data().TrimFront(hlen) + // Keep needed information before trimming header. p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) + dstAddr := hdr.DestinationAddress() + // Skip the ip header, then deliver the error. + pkt.Data().DeleteFront(hlen) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data().TrimFront(header.ICMPv4MinimumSize) - switch h.Code() { + mtu := h.MTU() + code := h.Code() + pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: - networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } @@ -383,6 +387,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // icmpReason is a marker interface for IPv4 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. isForwarding() bool } @@ -442,6 +448,39 @@ func (r *icmpReasonParamProblem) isForwarding() bool { return r.forwarding } +// icmpReasonNetworkUnreachable is an error in which the network specified in +// the internet destination field of the datagram is unreachable. +type icmpReasonNetworkUnreachable struct{} + +func (*icmpReasonNetworkUnreachable) isICMPReason() {} +func (*icmpReasonNetworkUnreachable) isForwarding() bool { + // If we hit a Net Unreachable error, then we know we are operating as + // a router. As per RFC 792 page 5, Destination Unreachable Message, + // + // If, according to the information in the gateway's routing tables, + // the network specified in the internet destination field of a + // datagram is unreachable, e.g., the distance to the network is + // infinity, the gateway may send a destination unreachable message to + // the internet source host of the datagram. + return true +} + +// icmpReasonFragmentationNeeded is an error where a packet requires +// fragmentation while also having the Don't Fragment flag set, as per RFC 792 +// page 3, Destination Unreachable Message. +type icmpReasonFragmentationNeeded struct{} + +func (*icmpReasonFragmentationNeeded) isICMPReason() {} +func (*icmpReasonFragmentationNeeded) isForwarding() bool { + // If we hit a Don't Fragment error, then we know we are operating as a router. + // As per RFC 792 page 4, Destination Unreachable Message, + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -610,6 +649,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetworkUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4NetUnreachable) + counter = sent.dstUnreachable + case *icmpReasonFragmentationNeeded: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a0bc06465..23178277a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -62,9 +63,15 @@ const ( fragmentblockSize = 8 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -81,6 +88,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -150,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { e.mu.Lock() defer e.mu.Unlock() + if !e.setForwarding(forwarding) { + return + } + if forwarding { // There does not seem to be an RFC requirement for a node to join the all // routers multicast address but @@ -433,6 +464,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn } if packetMustBeFragmented(pkt, networkMTU) { + h := header.IPv4(pkt.NetworkHeader().View()) + if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment + // is set but the packet must be fragmented for the non-forwarding case. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -599,22 +636,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { - // As per RFC 3927 section 7, - // - // A router MUST NOT forward a packet with an IPv4 Link-Local source or - // destination address, irrespective of the router's default route - // configuration or routes obtained from dynamic routing protocols. - // - // A router which receives a packet with an IPv4 Link-Local source or - // destination address MUST NOT forward the packet. This prevents - // forwarding of packets back onto the network segment from which they - // originated, or to any other segment. - return nil + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } ttl := h.TTL() @@ -624,7 +664,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // If the gateway processing a datagram finds the time to live field // is zero it must discard the datagram. The gateway may also notify // the source host via the time exceeded message. - return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } if opts := h.Options(); len(opts) != 0 { @@ -635,10 +680,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { pointer: optProblem.Pointer, forwarding: true, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - e.stats.ip.MalformedPacketsReceived.Increment() } - return nil // option problems are not reported locally. + return &ip.ErrParameterProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -655,18 +698,44 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -680,10 +749,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 792, page 4, Destination Unreachable: + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + // + // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit + // while sending the packet and returns this error iff fragmentation is + // necessary and the bit is also set. + _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -764,6 +851,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats + stats.ip.ValidPacketsReceived.Increment() srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -794,11 +882,30 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) addressEndpoint.DecRef() pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() + case *ip.ErrMessageTooLong: + stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + stats.ip.Forwarding.Errors.Increment() return } @@ -955,8 +1062,8 @@ func (e *endpoint) Close() { // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err == nil { @@ -967,8 +1074,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) } @@ -981,8 +1088,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() loopback := e.nic.IsLoopback() return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { @@ -1067,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1088,12 +1194,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - ids []uint32 hashIV uint32 @@ -1206,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 7d413c455..da9cc0ae8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -112,67 +112,103 @@ func TestExcludeBroadcast(t *testing.T) { }) } +type forwardedPacket struct { + fragments []fragmentInfo +} + func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 randomTimeOffset = 0x10203040 ) - ipv4Addr1 := tcpip.AddressWithPrefix{ + incomingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, } - ipv4Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), PrefixLen: 8, } - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") + multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code + name string + TTL uint8 + sourceAddr tcpip.Address + destAddr tcpip.Address + expectErrorICMP bool + ipFlags uint8 + mtu uint32 + payloadLength int + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectPacketForwarded bool + expectedFragmentsForwarded []fragmentInfo }{ { name: "TTL of zero", TTL: 0, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, + name: "TTL of one", + TTL: 1, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + name: "four EOL options", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: header.IPv4Options{0, 0, 0, 0}, }, { - name: "TS type 1 full", - TTL: 2, + name: "TS type 1 full", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, @@ -183,8 +219,11 @@ func TestForwarding(t *testing.T) { icmpCode: header.ICMPv4UnusedCode, }, { - name: "TS type 0", - TTL: 2, + name: "TS type 0", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, @@ -201,10 +240,14 @@ func TestForwarding(t *testing.T) { 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, + expectPacketForwarded: true, }, { - name: "end of options list", - TTL: 2, + name: "end of options list", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, @@ -220,11 +263,89 @@ func TestForwarding(t *testing.T) { 0, 0, 0, // 7 bytes unknown option removed. 0, 0, 0, 0, }, + expectPacketForwarded: true, + }, + { + name: "Network unreachable", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: unreachableIPv4Addr, + expectErrorICMP: true, + mtu: ipv4.MaxTotalSize, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4NetUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + destAddr: multicastIPv4Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: linkLocalIPv4Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv4Addr, + destAddr: remoteIPv4Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Fragmentation needed and DF set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + ipFlags: header.IPv4FlagDontFragment, + // We've picked this MTU because it is: + // + // 1) Greater than the minimum MTU that IPv4 hosts are required to process + // (576 bytes). As per RFC 1812, Section 4.3.2.3: + // + // The ICMP datagram SHOULD contain as much of the original datagram as + // possible without the length of the ICMP datagram exceeding 576 bytes. + // + // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a + // complete ICMP packet on the incoming endpoint (and make assertions about + // it). + // + // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose + // size exceeds the MTU. + mtu: 1000, + payloadLength: 1004, + expectErrorICMP: true, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4FragmentationNeeded, + }, + { + name: "Fragmentation needed and DF not set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: 1000, + payloadLength: 1004, + expectPacketForwarded: true, + // Combined, these fragments have length of 1012 octets, which is equal to + // the length of the payload (1004 octets), plus the length of the ICMP + // header (8 octets). + expectedFragmentsForwarded: []fragmentInfo{ + // The first fragment has a length of the greatest multiple of 8 which is + // less than or equal to to `mtu - header.IPv4MinimumSize`. + {offset: 0, payloadSize: uint16(976), more: true}, + // The next fragment holds the rest of the packet. + {offset: uint16(976), payloadSize: 36, more: false}, + }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, @@ -236,46 +357,52 @@ func TestForwarding(t *testing.T) { clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) } - e2 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + expectedEmittedPacketCount := 1 + if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { + expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) + } + outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv4Addr.Subnet(), + NIC: incomingNICID, }, { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, + Destination: outgoingIPv4Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpHeaderLength := header.ICMPv4MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv4Echo) @@ -284,11 +411,12 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(^header.Checksum(icmp, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, + TotalLength: uint16(totalLength), Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, + Flags: test.ipFlags, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) @@ -303,51 +431,122 @@ func TestForwarding(t *testing.T) { requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } + // We expect the ICMP packet to contain as much of the original packet as + // possible up to a limit of 576 bytes, split between payload, IP header, + // and ICMP header. + expectedICMPPayloadLength := func() int { + maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize + maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) + } + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), + checker.SrcAddr(incomingIPv4Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), ), ) + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) + if test.expectPacketForwarded { + if len(test.expectedFragmentsForwarded) != 0 { + fragmentedPackets := []*stack.PacketBuffer{} + for i := 0; i < len(test.expectedFragmentsForwarded); i++ { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo fragment through outgoing NIC") + } + fragmentedPackets = append(fragmentedPackets, reply.Pkt) + } + + // The forwarded packet's TTL will have been decremented. + ipHeader := header.IPv4(requestPkt.NetworkHeader().View()) + ipHeader.SetTTL(ipHeader.TTL() - 1) + + // Forwarded packets have available header bytes equalling the sum of the + // maximum IP header size and the maximum size allocated for link layer + // headers. In this case, no size is allocated for link layer headers. + expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize + if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + t.Error(err) + } + } else { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), + checker.TTL(test.TTL-1), + checker.IPv4Options(test.forwardedOptions), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) } } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") + if reply, ok = outgoingEndpoint.Read(); ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + } + boolToInt := func(val bool) uint64 { + if val { + return 1 } + return 0 + } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { + t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } @@ -1170,13 +1369,25 @@ func TestIPv4Sanity(t *testing.T) { } } -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { +// compareFragments compares the contents of a set of fragmented packets against +// the contents of a source packet. +// +// If withIPHeader is set to true, we will validate the fragmented packets' IP +// headers against the source packet's IP header. If set to false, we validate +// the fragmented packets' IP headers against each other. +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error { // Make a complete array of the sourcePacket packet. - source := header.IPv4(packets[0].NetworkHeader().View()) + var source header.IPv4 vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) + + // If the packet to be fragmented contains an IPv4 header, use that header for + // validating fragment headers. Else, use the header of the first fragment. + if withIPHeader { + source = header.IPv4(vv.ToView()) + } else { + source = header.IPv4(packets[0].NetworkHeader().View()) + source = append(source, vv.ToView()...) + } // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -1199,12 +1410,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB if got := fragmentIPHeader.TransportProtocol(); got != proto { return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } + if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes) + } if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } @@ -1220,6 +1431,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) + + // If we are validating against the original IP header, we should exclude the + // ID field, which will only be set fo fragmented packets. + if withIPHeader { + fragmentIPHeader.SetID(0) + fragmentIPHeader.SetChecksum(0) + fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum()) + } if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } @@ -1348,7 +1567,7 @@ func TestFragmentationWritePacket(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) @@ -1429,7 +1648,7 @@ func TestFragmentationWritePackets(t *testing.T) { } fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index a637f9d50..d1f9e3cf5 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index db998e83e..f99cbf8f3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1319db32b..307e1972d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe return } + // Keep needed information before trimming header. + p := hdr.TransportProtocol() + dstAddr := hdr.DestinationAddress() + // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6MinimumSize) if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // because they don't have the transport headers. return } + p = fragHdr.TransportProtocol() // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) } - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } + pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: @@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + code := header.ICMPv6(hdr).Code() + pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) + switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -741,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - stack := e.protocol.stack - - // Is the networking stack operating as a router? - if !stack.Forwarding(ProtocolNumber) { - // ... No, silently drop the packet. + if !e.Forwarding() { received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -951,6 +951,19 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo // icmpReason is a marker interface for IPv6 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. + isForwarding() bool + // respondToMulticast indicates whether this error falls under the exception + // outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are two + // exceptions to this rule: (1) the Packet Too Big Message (Section 3.2) to + // allow Path MTU discovery to work for IPv6 multicast, and (2) the Parameter + // Problem Message, Code 2 (Section 3.4) reporting an unrecognized IPv6 + // option (see Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondsToMulticast() bool } // icmpReasonParameterProblem is an error during processing of extension headers @@ -958,18 +971,6 @@ type icmpReason interface { type icmpReasonParameterProblem struct { code header.ICMPv6Code - // respondToMulticast indicates that we are sending a packet that falls under - // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: - // - // (e.3) A packet destined to an IPv6 multicast address. (There are - // two exceptions to this rule: (1) the Packet Too Big Message - // (Section 3.2) to allow Path MTU discovery to work for IPv6 - // multicast, and (2) the Parameter Problem Message, Code 2 - // (Section 3.4) reporting an unrecognized IPv6 option (see - // Section 4.2 of [IPv6]) that has the Option Type highest- - // order two bits set to 10). - respondToMulticast bool - // pointer is defined in the RFC 4443 setion 3.4 which reads: // // Pointer Identifies the octet offset within the invoking packet @@ -979,9 +980,20 @@ type icmpReasonParameterProblem struct { // packet if the field in error is beyond what can fit // in the maximum size of an ICMPv6 error message. pointer uint32 + + forwarding bool + + respondToMulticast bool } func (*icmpReasonParameterProblem) isICMPReason() {} +func (p *icmpReasonParameterProblem) isForwarding() bool { + return p.forwarding +} + +func (p *icmpReasonParameterProblem) respondsToMulticast() bool { + return p.respondToMulticast +} // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. @@ -989,12 +1001,76 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} + +func (*icmpReasonPortUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonNetUnreachable is an error where no route can be found to the +// network of the final destination. +type icmpReasonNetUnreachable struct{} + +func (*icmpReasonNetUnreachable) isICMPReason() {} + +func (*icmpReasonNetUnreachable) isForwarding() bool { + // If we hit a Network Unreachable error, then we also know we are + // operating as a router. As per RFC 4443 section 3.1: + // + // If the reason for the failure to deliver is lack of a matching + // entry in the forwarding node's routing table, the Code field is + // set to 0 (Network Unreachable). + return true +} + +func (*icmpReasonNetUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonFragmentationNeeded is an error where a packet is to big to be sent +// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. +type icmpReasonPacketTooBig struct{} + +func (*icmpReasonPacketTooBig) isICMPReason() {} + +func (*icmpReasonPacketTooBig) isForwarding() bool { + // If we hit a Packet Too Big error, then we know we are operating as a router. + // As per RFC 4443 section 3.2: + // + // A Packet Too Big MUST be sent by a router in response to a packet that it + // cannot forward because the packet is larger than the MTU of the outgoing + // link. + return true +} + +func (*icmpReasonPacketTooBig) respondsToMulticast() bool { + return true +} + // icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in // transit to its final destination, as per RFC 4443 section 3.3. type icmpReasonHopLimitExceeded struct{} func (*icmpReasonHopLimitExceeded) isICMPReason() {} +func (*icmpReasonHopLimitExceeded) isForwarding() bool { + // If we hit a Hop Limit Exceeded error, then we know we are operating + // as a router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard + // the packet and originate an ICMPv6 Time Exceeded message with Code + // 0 to the source of the packet. This indicates either a routing + // loop or too small an initial Hop Limit value. + return true +} + +func (*icmpReasonHopLimitExceeded) respondsToMulticast() bool { + return false +} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -1002,6 +1078,14 @@ type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} + +func (*icmpReasonReassemblyTimeout) respondsToMulticast() bool { + return false +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { @@ -1030,25 +1114,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). // - var allowResponseToMulticast bool - if reason, ok := reason.(*icmpReasonParameterProblem); ok { - allowResponseToMulticast = reason.respondToMulticast - } - + allowResponseToMulticast := reason.respondsToMulticast() isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } - // If we hit a Hop Limit Exceeded error, then we know we are operating as a - // router. As per RFC 4443 section 3.3: - // - // If a router receives a packet with a Hop Limit of zero, or if a - // router decrements a packet's Hop Limit to zero, it MUST discard the - // packet and originate an ICMPv6 Time Exceeded message with Code 0 to - // the source of the packet. This indicates either a routing loop or - // too small an initial Hop Limit value. - // // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. @@ -1058,7 +1129,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // packet as "multicast addresses must not be used as source addresses in IPv6 // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + if reason.isForwarding() || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -1147,6 +1218,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) + counter = sent.dstUnreachable + case *icmpReasonPacketTooBig: + icmpHdr.SetType(header.ICMPv6PacketTooBig) + icmpHdr.SetCode(header.ICMPv6UnusedCode) + counter = sent.packetTooBig case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f7510c243..95e11ac51 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -63,6 +63,11 @@ const ( buckets = 2048 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + // policyTable is the default policy table defined in RFC 6724 section 2.1. // // A more human-readable version: @@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 { var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -187,6 +193,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -405,27 +417,39 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t } } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.setForwarding(forwarding) { + return + } + allRoutersGroups := [...]tcpip.Address{ header.IPv6AllRoutersInterfaceLocalMulticastAddress, header.IPv6AllRoutersLinkLocalMulticastAddress, header.IPv6AllRoutersSiteLocalMulticastAddress, } - e.mu.Lock() - defer e.mu.Unlock() - if forwarding { - // When transitioning into an IPv6 router, host-only state (NDP discovered - // routers, discovered on-link prefixes, and auto-generated addresses) is - // cleaned up/invalidated and NDP router solicitations are stopped. - e.mu.ndp.stopSolicitingRouters() - e.mu.ndp.cleanupState(true /* hostOnly */) - // As per RFC 4291 section 2.8: // // A router is required to recognize all addresses that a host is @@ -449,28 +473,19 @@ func (e *endpoint) transitionForwarding(forwarding bool) { panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) } } - - return - } - - for _, g := range allRoutersGroups { - switch err := e.leaveGroupLocked(g).(type) { - case nil: - case *tcpip.ErrBadLocalAddress: - // The endpoint may have already left the multicast group. - default: - panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } else { + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } } } - // When transitioning into an IPv6 host, NDP router solicitations are - // started if the endpoint is enabled. - // - // If the endpoint is not currently enabled, routers will be solicited when - // the endpoint becomes enabled (if it is still a host). - if e.Enabled() { - e.mu.ndp.startSolicitingRouters() - } + e.mu.ndp.forwardingChanged(forwarding) } // Enable implements stack.NetworkEndpoint. @@ -552,17 +567,7 @@ func (e *endpoint) Enable() tcpip.Error { e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyway. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !e.protocol.Forwarding() { - e.mu.ndp.startSolicitingRouters() - } - + e.mu.ndp.startSolicitingRouters() return nil } @@ -613,7 +618,7 @@ func (e *endpoint) disableLocked() { return true }) - e.mu.ndp.cleanupState(false /* hostOnly */) + e.mu.ndp.cleanupState() // The endpoint may have already left the multicast group. switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { @@ -786,6 +791,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol } if packetMustBeFragmented(pkt, networkMTU) { + if pkt.NetworkPacketInfo.IsForwardedPacket { + // As per RFC 2460, section 4.5: + // Unlike IPv4, fragmentation in IPv6 is performed only by source nodes, + // not by routers along a packet's delivery path. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -928,16 +939,19 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv6(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { - // As per RFC 4291 section 2.5.6, - // - // Routers must not forward any packets with Link-Local source or - // destination addresses to other links. - return nil + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } hopLimit := h.HopLimit() @@ -949,21 +963,56 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // packet and originate an ICMPv6 Time Exceeded message with Code 0 to // the source of the packet. This indicates either a routing loop or // too small an initial Hop Limit value. - return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + // Check extension headers for any errors requiring action during forwarding. + if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil { + return &ip.ErrParameterProblem{} + } + + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning the + // ICMP packet because the original error is more relevant to the caller. + _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -975,10 +1024,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 4443, section 3.2: + // A Packet Too Big MUST be sent by a router in response to a packet that + // it cannot forward because the packet is larger than the MTU of the + // outgoing link. + _ = e.protocol.returnError(&icmpReasonPacketTooBig{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -1059,6 +1121,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip + stats.ValidPacketsReceived.Increment() + srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1075,15 +1139,54 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.InvalidDestinationAddressesReceived.Increment() return } + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + e.stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + e.stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + e.stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + e.stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment() + case *ip.ErrMessageTooLong: + e.stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + e.stats.ip.Forwarding.Errors.Increment() + return + } - _ = e.forwardPacket(pkt) + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + // iptables is telling us to drop the packet. + stats.IPTablesInputDropped.Increment() return } + // Any returned error is only useful for terminating execution early, but + // we have nothing left to do, so we can drop it. + _ = e.processExtensionHeaders(h, pkt, false /* forwarding */) +} + +// processExtensionHeaders processes the extension headers in the given packet. +// Returns an error if the processing of a header failed or if the packet should +// be discarded. +func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error { + stats := e.stats.ip + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1094,15 +1197,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) - // iptables filtering. All packets that reach here are intended for - // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { - // iptables is telling us to drop the packet. - stats.IPTablesInputDropped.Increment() - return - } - var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption @@ -1115,22 +1209,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) extHdr, done, err := it.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break } + // As per RFC 8200, section 4: + // + // Extension headers (except for the Hop-by-Hop Options header) are + // not processed, inserted, or deleted by any node along a packet's + // delivery path until the packet reaches the node identified in the + // Destination Address field of the IPv6 header. + // + // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension + // header is restricted to appear first in the list of extension headers. + // + // Therefore, we can immediately return once we hit any header other + // than the Hop-by-Hop header while forwarding a packet. + if forwarding { + if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok { + return nil + } + } + switch extHdr := extHdr.(type) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: previousHeaderStart, + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart) } optsIt := extHdr.Iter() @@ -1139,7 +1252,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1154,7 +1267,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // There MUST only be one option of this type, regardless of // value, per Hop-by-Hop header. stats.MalformedPacketsReceived.Increment() - return + return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert) } routerAlert = opt stats.OptionRouterAlertReceived.Increment() @@ -1162,10 +1275,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1180,10 +1293,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt)) } } } @@ -1205,8 +1319,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), + // For the sake of consistency, we're using the value of `forwarding` + // here, even though it should always be false if we've reached this + // point. If `forwarding` is true here, we're executing undefined + // behavior no matter what. + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr) } case header.IPv6FragmentExtHdr: @@ -1241,7 +1360,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if done { break @@ -1269,7 +1388,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) default: stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr) } } @@ -1278,7 +1397,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Drop the packet as it's marked as a fragment but has no payload. stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("fragment has no payload") } // As per RFC 2460 Section 4.5: @@ -1296,7 +1415,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) - return + return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen) } // The packet is a fragment, let's try to reassemble it. @@ -1310,14 +1429,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Parameter Problem, Code 0, message should be sent to the source of // the fragment, pointing to the Fragment Offset field of the fragment // packet. - if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { + lengthAfterReassembly := int(start) + fragmentPayloadLen + if lengthAfterReassembly > header.IPv6MaximumPayloadSize { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) - return + return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize) } // Note that pkt doesn't have its transport header set after reassembly, @@ -1339,7 +1459,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if ready { @@ -1361,7 +1481,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1372,10 +1492,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown destination header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1392,9 +1512,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, }, pkt) - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt)) } } @@ -1402,13 +1522,19 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + // Calculate the number of octets parsed from data. We want to remove all + // the data except the unparsed portion located at the end, which its size + // is extHdr.Buf.Size(). + trim := pkt.Data().Size() - extHdr.Buf.Size() + // For unfragmented packets, extHdr still contains the transport header. // Get rid of it. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data().Replace(extHdr.Buf) + trim += pkt.TransportHeader().View().Size() + + pkt.Data().DeleteFront(trim) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -1425,6 +1551,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) + return fmt.Errorf("destination port unreachable") case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1456,6 +1583,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownHeader, pointer: prevHdrIDOffset, }, pkt) + return fmt.Errorf("transport protocol unreachable") default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -1469,6 +1597,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) } } + return nil } // Close cleans up resources associated with the endpoint. @@ -1490,8 +1619,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1532,8 +1661,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -1610,8 +1739,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) } @@ -1833,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1858,12 +1986,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - fragmentation *fragmentation.Fragmentation } @@ -2038,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 40a793d6b..afc6c3547 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,8 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -3003,52 +3004,289 @@ func TestFragmentationErrors(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 ) - ipv6Addr1 := tcpip.AddressWithPrefix{ + incomingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10::1").To16()), PrefixLen: 64, } - ipv6Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11::1").To16()), PrefixLen: 64, } + multicastIPv6Addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("ff00::").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) + linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) + TTL uint8 + expectErrorICMP bool + expectPacketForwarded bool + payloadLength int + countUnrouteablePackets uint64 + sourceAddr tcpip.Address + destAddr tcpip.Address + icmpType header.ICMPv6Type + icmpCode header.ICMPv6Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectExtensionHeaderError bool }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { name: "TTL of one", TTL: 1, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "TTL of three", + TTL: 3, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Network unreachable", + TTL: 2, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: unreachableIPv6Addr, + icmpType: header.ICMPv6DstUnreachable, + icmpCode: header.ICMPv6NetworkUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + countUnrouteablePackets: 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectPacketForwarded: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: linkLocalIPv6Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv6Addr, + destAddr: remoteIPv6Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Hopbyhop with unknown option skippable action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption())) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with unknown option discard action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, }, { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, + name: "Hopbyhop with router alert option", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD))) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with two router alert options", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Can't fragment", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Can't fragment multicast", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectErrorICMP: true, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, } @@ -3059,41 +3297,60 @@ func TestForwarding(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) } - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv6Addr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: outgoingIPv6Addr.Subnet(), + NIC: outgoingNICID, }, { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, + Destination: multicastIPv6Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) - icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + transportProtocol := header.ICMPv6ProtocolNumber + extHdrBytes := []byte{} + extHdrChecker := checker.IPv6ExtHdr() + if test.extHdr != nil { + nextHdrID := hopByHopExtHdrID + extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber)) + transportProtocol = tcpip.TransportProtocolNumber(nextHdrID) + } + extHdrLen := len(extHdrBytes) + + ipHeaderLength := header.IPv6MinimumSize + icmpHeaderLength := header.ICMPv6MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv6EchoRequest) @@ -3101,52 +3358,72 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, + Src: test.sourceAddr, + Dst: test.destAddr, })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + copy(hdr.Prepend(extHdrLen), extHdrBytes) + ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, + PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + TransportProtocol: transportProtocol, HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(ProtocolNumber, requestPkt) + incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) + } + + // As per RFC 4443, page 9: + // + // The returned ICMP packet will contain as much of invoking packet + // as possible without the ICMPv6 packet exceeding the minimum IPv6 + // MTU. + expectedICMPPayloadLength := func() int { + maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), + checker.SrcAddr(incomingIPv6Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), - checker.ICMPv6Payload([]byte(hdr.View())), + checker.ICMPv6Type(test.icmpType), + checker.ICMPv6Code(test.icmpCode), + checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), ), ) - if n := e2.Drain(); n != 0 { + if n := outgoingEndpoint.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = outgoingEndpoint.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), + checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), + extHdrChecker, checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), checker.ICMPv6Code(header.ICMPv6UnusedCode), @@ -3154,9 +3431,46 @@ func TestForwarding(t *testing.T) { ), ) - if n := e1.Drain(); n != 0 { + if n := incomingEndpoint.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { + t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d6e0a81a6..f0ff111c5 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -48,7 +48,7 @@ const ( // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. - defaultHandleRAs = true + defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router @@ -301,10 +301,60 @@ type NDPDispatcher interface { OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } +var _ fmt.Stringer = HandleRAsConfiguration(0) + +// HandleRAsConfiguration enumerates when RAs may be handled. +type HandleRAsConfiguration int + +const ( + // HandlingRAsDisabled indicates that Router Advertisements will not be + // handled. + HandlingRAsDisabled HandleRAsConfiguration = iota + + // HandlingRAsEnabledWhenForwardingDisabled indicates that router + // advertisements will only be handled when forwarding is disabled. + HandlingRAsEnabledWhenForwardingDisabled + + // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always + // be handled, even when forwarding is enabled. + HandlingRAsAlwaysEnabled +) + +// String implements fmt.Stringer. +func (c HandleRAsConfiguration) String() string { + switch c { + case HandlingRAsDisabled: + return "HandlingRAsDisabled" + case HandlingRAsEnabledWhenForwardingDisabled: + return "HandlingRAsEnabledWhenForwardingDisabled" + case HandlingRAsAlwaysEnabled: + return "HandlingRAsAlwaysEnabled" + default: + return fmt.Sprintf("HandleRAsConfiguration(%d)", c) + } +} + +// enabled returns true iff Router Advertisements may be handled given the +// specified forwarding status. +func (c HandleRAsConfiguration) enabled(forwarding bool) bool { + switch c { + case HandlingRAsDisabled: + return false + case HandlingRAsEnabledWhenForwardingDisabled: + return !forwarding + case HandlingRAsAlwaysEnabled: + return true + default: + panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c)) + } +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. + // + // Ignored unless configured to handle Router Advertisements. MaxRtrSolicitations uint8 // The amount of time between transmitting Router Solicitation messages. @@ -318,8 +368,9 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements are processed. - HandleRAs bool + // HandleRAs is the configuration for when Router Advertisements should be + // handled. + HandleRAs HandleRAsConfiguration // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This @@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t delete(tempAddrs, tempAddr) } -// removeSLAACAddresses removes all SLAAC addresses. -// -// If keepLinkLocal is false, the SLAAC generated link-local address is removed. -// -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - var linkLocalPrefixes int - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if keepLinkLocal && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } -} - // cleanupState cleans up ndp's state. // -// If hostOnly is true, then only host-specific state is cleaned up. -// // This function invalidates all discovered on-link prefixes, discovered // routers, and auto-generated addresses. // -// If hostOnly is true, then the link-local auto-generated address aren't -// invalidated as routers are also expected to generate a link-local address. -// // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) +func (ndp *ndpState) cleanupState() { + for prefix, state := range ndp.slaacPrefixes { + ndp.invalidateSLAACPrefix(prefix, state) + } for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // +// If ndp is not configured to handle Router Advertisements, routers will not +// be solicited as there is no point soliciting routers if we don't handle their +// advertisements. +// // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicitTimer.timer != nil { @@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() { return } + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + return + } + // Calculate the random delay before sending our first RS, as per RFC // 4861 section 6.3.7. var delay time.Duration @@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() { } } +// forwardingChanged handles a change in forwarding configuration. +// +// If transitioning to a host, router solicitation will be started. Otherwise, +// router solicitation will be stopped if NDP is not configured to handle RAs +// as a router. +// +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) forwardingChanged(forwarding bool) { + if forwarding { + if ndp.configs.HandleRAs.enabled(forwarding) { + return + } + + ndp.stopSolicitingRouters() + return + } + + // Solicit routers when transitioning to a host. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if ndp.ep.Enabled() { + ndp.startSolicitingRouters() + } +} + // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..234e34952 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { } func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } + const nicID = 1 handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { var extHdrs header.IPv6ExtHdrSerializer @@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) { }, } + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + for _, typ := range types { for _, isRouter := range []bool{false, true} { name := typ.name @@ -875,13 +872,35 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatal("cannot find network endpoint instance for IPv6") + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}) + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost @@ -906,12 +925,12 @@ func TestNDPValidation(t *testing.T) { // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + t.Errorf("got invalid.Value() = %d, want = 0", got) } - // RouterOnlyPacketsReceivedByHost count should initially be 0. + // Should initially not have dropped any packets. if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + t.Errorf("got routerOnly.Value() = %d, want = 0", got) } if t.Failed() { @@ -931,18 +950,18 @@ func TestNDPValidation(t *testing.T) { want = 1 } if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + t.Errorf("got invalid.Value() = %d, want = %d", got, want) } want = 0 if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. + // Router only packets are expected to be dropped when operating + // as a host. want = 1 } if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) } - }) } }) diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index c2758352f..2f18f60e8 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -29,6 +29,10 @@ type Stats struct { // ICMP holds ICMPv6 statistics. ICMP tcpip.ICMPv6Stats + + // UnhandledRouterAdvertisements is the number of Router Advertisements that + // were observed but not handled. + UnhandledRouterAdvertisements *tcpip.StatCounter } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. |