summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/checker.go11
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/stats_test.go2
-rw-r--r--pkg/tcpip/network/internal/ip/errors.go10
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go30
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD5
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go68
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go16
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go8
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go125
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go195
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go52
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/testutil/BUILD5
-rw-r--r--pkg/tcpip/testutil/testutil.go68
-rw-r--r--pkg/tcpip/testutil/testutil_unsafe.go (renamed from pkg/tcpip/network/internal/testutil/testutil_unsafe.go)0
19 files changed, 445 insertions, 169 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 12c39dfa3..18e6cc3cd 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1607,6 +1607,17 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
}
}
+// IPv6UnknownOption validates that an extension header option is the
+// unknown header option.
+func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ _, ok := opt.(*header.IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
+ }
+ }
+}
+
// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
func IgnoreCmpPath(paths ...string) cmp.Option {
ignores := map[string]struct{}{}
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 6905b9ccb..a72eb1aad 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -47,7 +47,7 @@ go_test(
library = ":arp",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index e867b3c3f..0df39ae81 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go
index 50fabfd79..d3577b377 100644
--- a/pkg/tcpip/network/internal/ip/errors.go
+++ b/pkg/tcpip/network/internal/ip/errors.go
@@ -34,13 +34,13 @@ func (*ErrTTLExceeded) isForwardingError() {}
func (*ErrTTLExceeded) String() string { return "ttl exceeded" }
-// ErrIPOptProblem indicates the received packet had a problem with an IP
-// option.
-type ErrIPOptProblem struct{}
+// ErrParameterProblem indicates the received packet had a problem with an IP
+// parameter.
+type ErrParameterProblem struct{}
-func (*ErrIPOptProblem) isForwardingError() {}
+func (*ErrParameterProblem) isForwardingError() {}
-func (*ErrIPOptProblem) String() string { return "ip option problem" }
+func (*ErrParameterProblem) String() string { return "parameter problem" }
// ErrLinkLocalSourceAddress indicates the received packet had a link-local
// source address.
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 392f0b0c7..68b8b550e 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -16,7 +16,7 @@ package ip
import "gvisor.dev/gvisor/pkg/tcpip"
-// LINT.IfChange(MultiCounterIPStats)
+// LINT.IfChange(MultiCounterIPForwardingStats)
// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter
// may have several versions.
@@ -38,11 +38,30 @@ type MultiCounterIPForwardingStats struct {
// because they contained a link-local destination address.
LinkLocalDestination tcpip.MultiCounterStat
+ // ExtensionHeaderProblem is the number of IP packets which were dropped
+ // because of a problem encountered when processing an IPv6 extension
+ // header.
+ ExtensionHeaderProblem tcpip.MultiCounterStat
+
// Errors is the number of IP packets received which could not be
// successfully forwarded.
Errors tcpip.MultiCounterStat
}
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
+ m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
+ m.Errors.Init(a.Errors, b.Errors)
+ m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
+ m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
+ m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
+ m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
+}
+
+// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
+
+// LINT.IfChange(MultiCounterIPStats)
+
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
@@ -120,15 +139,6 @@ type MultiCounterIPStats struct {
}
// Init sets internal counters to track a and b counters.
-func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
- m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
- m.Errors.Init(a.Errors, b.Errors)
- m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
- m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
- m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
-}
-
-// Init sets internal counters to track a and b counters.
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index 1c4f583c7..cec3e62c4 100644
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -4,10 +4,7 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
- srcs = [
- "testutil.go",
- "testutil_unsafe.go",
- ],
+ srcs = ["testutil.go"],
visibility = [
"//pkg/tcpip/network/arp:__pkg__",
"//pkg/tcpip/network/internal/fragmentation:__pkg__",
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index e2cf24b67..605e9ef8d 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -19,8 +19,6 @@ package testutil
import (
"fmt"
"math/rand"
- "reflect"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi
}
return pkt
}
-
-func checkFieldCounts(ref, multi reflect.Value) error {
- refTypeName := ref.Type().Name()
- multiTypeName := multi.Type().Name()
- refNumField := ref.NumField()
- multiNumField := multi.NumField()
-
- if refNumField != multiNumField {
- return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
- }
-
- return nil
-}
-
-func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
- s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
- if !ok {
- return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
- }
-
- // The field names are expected to match (case insensitive).
- if !strings.EqualFold(refName, multiName) {
- return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
- }
-
- base := (*s).Value()
- m.Increment()
- if (*s).Value() != base+1 {
- return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
- }
-
- return nil
-}
-
-// ValidateMultiCounterStats verifies that every counter stored in multi is
-// correctly tracking its counterpart in the given counters.
-func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
- for _, c := range counters {
- if err := checkFieldCounts(c, multi); err != nil {
- return err
- }
- }
-
- for i := 0; i < multi.NumField(); i++ {
- multiName := multi.Type().Field(i).Name
- multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
-
- if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
- for _, c := range counters {
- if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
- return err
- }
- }
- } else {
- var countersNextField []reflect.Value
- for _, c := range counters {
- countersNextField = append(countersNextField, c.Field(i))
- }
- if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7ee0495d9..c90974693 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -62,7 +62,7 @@ go_test(
library = ":ipv4",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b11e56c6a..4031032d0 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -645,7 +645,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
forwarding: true,
}, pkt)
}
- return &ip.ErrIPOptProblem{}
+ return &ip.ErrParameterProblem{}
}
copied := copy(opts, newOpts)
if copied != len(newOpts) {
@@ -827,7 +827,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
stats.ip.Forwarding.ExhaustedTTL.Increment()
case *ip.ErrNoRoute:
stats.ip.Forwarding.Unrouteable.Increment()
- case *ip.ErrIPOptProblem:
+ case *ip.ErrParameterProblem:
e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
default:
@@ -990,8 +990,8 @@ func (e *endpoint) Close() {
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err == nil {
@@ -1002,8 +1002,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
}
@@ -1016,8 +1016,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
loopback := e.nic.IsLoopback()
return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
index a637f9d50..d1f9e3cf5 100644
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index db998e83e..f99cbf8f3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -45,6 +45,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index ebb0b73df..247a07dc2 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -984,11 +984,15 @@ type icmpReasonParameterProblem struct {
// packet if the field in error is beyond what can fit
// in the maximum size of an ICMPv6 error message.
pointer uint32
+
+ // forwarding indicates that the problem arose while we were trying to forward
+ // a packet.
+ forwarding bool
}
func (*icmpReasonParameterProblem) isICMPReason() {}
-func (*icmpReasonParameterProblem) isForwarding() bool {
- return false
+func (p *icmpReasonParameterProblem) isForwarding() bool {
+ return p.forwarding
}
// icmpReasonPortUnreachable is an error where the transport protocol has no
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 659057fa7..029d5f51b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -941,6 +941,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
+ // Check extension headers for any errors requiring action during forwarding.
+ if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil {
+ return &ip.ErrParameterProblem{}
+ }
+
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
switch err.(type) {
case nil:
@@ -1084,6 +1089,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
e.stats.ip.Forwarding.ExhaustedTTL.Increment()
case *ip.ErrNoRoute:
e.stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrParameterProblem:
+ e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment()
default:
panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
}
@@ -1091,6 +1098,28 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
return
}
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IPTablesInputDropped.Increment()
+ return
+ }
+
+ // Any returned error is only useful for terminating execution early, but
+ // we have nothing left to do, so we can drop it.
+ _ = e.processExtensionHeaders(h, pkt, false /* forwarding */)
+}
+
+// processExtensionHeaders processes the extension headers in the given packet.
+// Returns an error if the processing of a header failed or if the packet should
+// be discarded.
+func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error {
+ stats := e.stats.ip
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// Create a VV to parse the packet. We don't plan to modify anything here.
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1101,15 +1130,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
vv.AppendViews(pkt.Data().Views())
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
- // iptables filtering. All packets that reach here are intended for
- // this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
- // iptables is telling us to drop the packet.
- stats.IPTablesInputDropped.Increment()
- return
- }
-
var (
hasFragmentHeader bool
routerAlert *header.IPv6RouterAlertOption
@@ -1122,22 +1142,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
extHdr, done, err := it.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
}
+ // As per RFC 8200, section 4:
+ //
+ // Extension headers (except for the Hop-by-Hop Options header) are
+ // not processed, inserted, or deleted by any node along a packet's
+ // delivery path until the packet reaches the node identified in the
+ // Destination Address field of the IPv6 header.
+ //
+ // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension
+ // header is restricted to appear first in the list of extension headers.
+ //
+ // Therefore, we can immediately return once we hit any header other
+ // than the Hop-by-Hop header while forwarding a packet.
+ if forwarding {
+ if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok {
+ return nil
+ }
+ }
+
switch extHdr := extHdr.(type) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
_ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: previousHeaderStart,
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart)
}
optsIt := extHdr.Iter()
@@ -1146,7 +1185,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1161,7 +1200,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// There MUST only be one option of this type, regardless of
// value, per Hop-by-Hop header.
stats.MalformedPacketsReceived.Increment()
- return
+ return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert)
}
routerAlert = opt
stats.OptionRouterAlertReceived.Increment()
@@ -1169,10 +1208,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1187,10 +1226,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt))
}
}
}
@@ -1212,8 +1252,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
+ // For the sake of consistency, we're using the value of `forwarding`
+ // here, even though it should always be false if we've reached this
+ // point. If `forwarding` is true here, we're executing undefined
+ // behavior no matter what.
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr)
}
case header.IPv6FragmentExtHdr:
@@ -1248,7 +1293,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1276,7 +1321,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
default:
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr)
}
}
@@ -1285,7 +1330,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Drop the packet as it's marked as a fragment but has no payload.
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("fragment has no payload")
}
// As per RFC 2460 Section 4.5:
@@ -1303,7 +1348,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
}, pkt)
- return
+ return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen)
}
// The packet is a fragment, let's try to reassemble it.
@@ -1317,14 +1362,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Parameter Problem, Code 0, message should be sent to the source of
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
- if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
+ lengthAfterReassembly := int(start) + fragmentPayloadLen
+ if lengthAfterReassembly > header.IPv6MaximumPayloadSize {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
}, pkt)
- return
+ return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize)
}
// Note that pkt doesn't have its transport header set after reassembly,
@@ -1346,7 +1392,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if ready {
@@ -1368,7 +1414,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1379,10 +1425,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown destination header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1399,9 +1445,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
}, pkt)
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt))
}
}
@@ -1432,6 +1478,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
_ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
+ return fmt.Errorf("destination port unreachable")
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -1463,6 +1510,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownHeader,
pointer: prevHdrIDOffset,
}, pkt)
+ return fmt.Errorf("transport protocol unreachable")
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -1476,6 +1524,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
}
}
+ return nil
}
// Close cleans up resources associated with the endpoint.
@@ -1497,8 +1546,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1539,8 +1588,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
@@ -1617,8 +1666,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 4fbe39528..8ebca735b 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -31,8 +31,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
+ ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt.Clone()
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+ tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
@@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -3025,6 +3026,7 @@ func TestForwarding(t *testing.T) {
tests := []struct {
name string
+ extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker)
TTL uint8
expectErrorICMP bool
expectPacketForwarded bool
@@ -3036,6 +3038,7 @@ func TestForwarding(t *testing.T) {
expectPacketUnrouteableError bool
expectLinkLocalSourceError bool
expectLinkLocalDestError bool
+ expectExtensionHeaderError bool
}{
{
name: "TTL of zero",
@@ -3108,6 +3111,158 @@ func TestForwarding(t *testing.T) {
destAddr: remoteIPv6Addr2,
expectLinkLocalSourceError: true,
},
+ {
+ name: "Hopbyhop with unknown option skippable action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption()))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with router alert option",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with two router alert options",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
}
for _, test := range tests {
@@ -3150,7 +3305,17 @@ func TestForwarding(t *testing.T) {
t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
}
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
+ transportProtocol := header.ICMPv6ProtocolNumber
+ extHdrBytes := []byte{}
+ extHdrChecker := checker.IPv6ExtHdr()
+ if test.extHdr != nil {
+ nextHdrID := hopByHopExtHdrID
+ extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber))
+ transportProtocol = tcpip.TransportProtocolNumber(nextHdrID)
+ }
+ extHdrLen := len(extHdrBytes)
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize + extHdrLen)
icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
@@ -3162,10 +3327,11 @@ func TestForwarding(t *testing.T) {
Src: test.sourceAddr,
Dst: test.destAddr,
}))
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: header.ICMPv6ProtocolNumber,
+ TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
DstAddr: test.destAddr,
@@ -3205,10 +3371,11 @@ func TestForwarding(t *testing.T) {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
checker.SrcAddr(test.sourceAddr),
checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
+ extHdrChecker,
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6EchoRequest),
checker.ICMPv6Code(header.ICMPv6UnusedCode),
@@ -3249,6 +3416,10 @@ func TestForwarding(t *testing.T) {
if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
}
+
+ if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want)
+ }
})
}
}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index e5590ecc0..ce9cebdaa 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
// Regardless how the address was obtained, it will be acquired before it is
// returned.
func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
- a.mu.Lock()
- defer a.mu.Unlock()
+ lookup := func() *addressState {
+ if addrState, ok := a.mu.endpoints[localAddr]; ok {
+ if !addrState.IsAssigned(allowTemp) {
+ return nil
+ }
- if addrState, ok := a.mu.endpoints[localAddr]; ok {
- if !addrState.IsAssigned(allowTemp) {
- return nil
- }
+ if !addrState.IncRef() {
+ panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ }
- if !addrState.IncRef() {
- panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ return addrState
}
- return addrState
- }
-
- if f != nil {
- for _, addrState := range a.mu.endpoints {
- if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
- return addrState
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
}
}
+ return nil
+ }
+ // Avoid exclusive lock on mu unless we need to add a new address.
+ a.mu.RLock()
+ ep := lookup()
+ a.mu.RUnlock()
+
+ if ep != nil {
+ return ep
}
if !allowTemp {
return nil
}
+ // Acquire state lock in exclusive mode as we need to add a new temporary
+ // endpoint.
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ // Do the lookup again in case another goroutine added the address in the time
+ // we released and acquired the lock.
+ ep = lookup()
+ if ep != nil {
+ return ep
+ }
+
+ // Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
if err != nil {
@@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// expect no error.
panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
}
+
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 7e7415df4..f9acd4bb8 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1530,9 +1530,10 @@ type IGMPStats struct {
// IPForwardingStats collects stats related to IP forwarding (both v4 and v6).
type IPForwardingStats struct {
+ // LINT.IfChange(IPForwardingStats)
+
// Unrouteable is the number of IP packets received which were dropped
- // because the netstack could not construct a route to their
- // destination.
+ // because a route to their destination could not be constructed.
Unrouteable *StatCounter
// ExhaustedTTL is the number of IP packets received which were dropped
@@ -1547,9 +1548,16 @@ type IPForwardingStats struct {
// because they contained a link-local destination address.
LinkLocalDestination *StatCounter
+ // ExtensionHeaderProblem is the number of IP packets which were dropped
+ // because of a problem encountered when processing an IPv6 extension
+ // header.
+ ExtensionHeaderProblem *StatCounter
+
// Errors is the number of IP packets received which could not be
// successfully forwarded.
Errors *StatCounter
+
+ // LINT.ThenChange(network/internal/ip/stats.go:multiCounterIPForwardingStats)
}
// IPStats collects IP-specific stats (both v4 and v6).
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
index 472545a5d..02ee86ff1 100644
--- a/pkg/tcpip/testutil/BUILD
+++ b/pkg/tcpip/testutil/BUILD
@@ -5,7 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = True,
- srcs = ["testutil.go"],
+ srcs = [
+ "testutil.go",
+ "testutil_unsafe.go",
+ ],
visibility = ["//visibility:public"],
deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index 1aaed590f..f84d399fb 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -18,6 +18,8 @@ package testutil
import (
"fmt"
"net"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address {
}
return tcpip.Address(ip)
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go
index 5ff764800..5ff764800 100644
--- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
+++ b/pkg/tcpip/testutil/testutil_unsafe.go