summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go11
-rw-r--r--pkg/tcpip/header/parse/parse.go6
-rw-r--r--pkg/tcpip/link/channel/channel.go11
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go55
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/nested/nested.go8
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go13
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/stats_test.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go3
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD1
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go2
-rw-r--r--pkg/tcpip/network/internal/ip/errors.go85
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go1
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go115
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD5
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go68
-rw-r--r--pkg/tcpip/network/ip_test.go8
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go59
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go211
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go403
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go159
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go5
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go363
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go448
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go130
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go57
-rw-r--r--pkg/tcpip/network/ipv6/stats.go4
-rw-r--r--pkg/tcpip/socketops.go9
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go52
-rw-r--r--pkg/tcpip/stack/forwarding_test.go33
-rw-r--r--pkg/tcpip/stack/iptables.go1
-rw-r--r--pkg/tcpip/stack/iptables_types.go15
-rw-r--r--pkg/tcpip/stack/ndp_test.go1173
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go2
-rw-r--r--pkg/tcpip/stack/nic.go29
-rw-r--r--pkg/tcpip/stack/packet_buffer.go383
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go200
-rw-r--r--pkg/tcpip/stack/registration.go33
-rw-r--r--pkg/tcpip/stack/route.go12
-rw-r--r--pkg/tcpip/stack/stack.go152
-rw-r--r--pkg/tcpip/stack/stack_test.go32
-rw-r--r--pkg/tcpip/stdclock.go130
-rw-r--r--pkg/tcpip/stdclock_state.go26
-rw-r--r--pkg/tcpip/tcpip.go58
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go256
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go296
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go8
-rw-r--r--pkg/tcpip/tests/utils/utils.go42
-rw-r--r--pkg/tcpip/testutil/BUILD5
-rw-r--r--pkg/tcpip/testutil/testutil.go68
-rw-r--r--pkg/tcpip/testutil/testutil_unsafe.go (renamed from pkg/tcpip/network/internal/testutil/testutil_unsafe.go)0
-rw-r--r--pkg/tcpip/time_unsafe.go75
-rw-r--r--pkg/tcpip/timer_test.go32
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD12
-rw-r--r--pkg/tcpip/transport/tcp/connect.go28
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go12
-rw-r--r--pkg/tcpip/transport/tcp/segment.go9
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go67
-rw-r--r--pkg/tcpip/transport/tcp/snd.go11
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go184
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go4
72 files changed, 4029 insertions, 1708 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index aa30cfc85..ea46c30da 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -22,12 +22,14 @@ go_library(
"errors.go",
"sock_err_list.go",
"socketops.go",
+ "stdclock.go",
+ "stdclock_state.go",
"tcpip.go",
- "time_unsafe.go",
"timer.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/atomicbitops",
"//pkg/sync",
"//pkg/tcpip/buffer",
"//pkg/waiter",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 12c39dfa3..18e6cc3cd 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1607,6 +1607,17 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
}
}
+// IPv6UnknownOption validates that an extension header option is the
+// unknown header option.
+func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ _, ok := opt.(*header.IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
+ }
+ }
+}
+
// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
func IgnoreCmpPath(paths ...string) cmp.Option {
ignores := map[string]struct{}{}
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index ebb4b2c1d..1c913b5e1 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -60,9 +60,13 @@ func IPv4(pkt *stack.PacketBuffer) bool {
return false
}
ipHdr = header.IPv4(hdr)
+ length := int(ipHdr.TotalLength()) - len(hdr)
+ if length < 0 {
+ return false
+ }
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ pkt.Data().CapLength(length)
return true
}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f75ee34ab..ef9126deb 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -123,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) {
q.notify = notify
}
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -130,6 +133,7 @@ type Endpoint struct {
mtu uint32
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
+ SupportedGSOKind stack.SupportedGSO
// Outbound packet queue.
q *queue
@@ -211,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.LinkEPCapabilities
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (*Endpoint) GSOMaxSize() uint32 {
return 1 << 15
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.SupportedGSOKind
+}
+
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index f042df82e..d971194e6 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/binary",
"//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index feb79fe0e..bddb1d0a2 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,7 +45,6 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string {
}
}
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -134,6 +136,9 @@ type endpoint struct {
// wg keeps track of running goroutines.
wg sync.WaitGroup
+
+ // gsoKind is the supported kind of GSO.
+ gsoKind stack.SupportedGSO
}
// Options specify the details about the fd-based endpoint to be created.
@@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
if isSocket {
if opts.GSOMaxSize != 0 {
if opts.SoftwareGSOEnabled {
- e.caps |= stack.CapabilitySoftwareGSO
+ e.gsoKind = stack.SWGSOSupported
} else {
- e.caps |= stack.CapabilityHardwareGSO
+ e.gsoKind = stack.HWGSOSupported
}
e.gsoMaxSize = opts.GSOMaxSize
}
@@ -403,6 +408,35 @@ type virtioNetHdr struct {
csumOffset uint16
}
+// marshal serializes h to a newly-allocated byte slice, in little-endian byte
+// order.
+//
+// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used
+// for general serialization. This makes it difficult to use go-marshal for
+// virtio types, as go-marshal implicitly uses the native byte ordering.
+func (h *virtioNetHdr) marshal() []byte {
+ buf := [virtioNetHdrSize]byte{
+ 0: byte(h.flags),
+ 1: byte(h.gsoType),
+
+ // Manually lay out the fields in little-endian byte order. Little endian =>
+ // least significant bit goes to the lower address.
+
+ 2: byte(h.hdrLen),
+ 3: byte(h.hdrLen >> 8),
+
+ 4: byte(h.gsoSize),
+ 5: byte(h.gsoSize >> 8),
+
+ 6: byte(h.csumStart),
+ 7: byte(h.csumStart >> 8),
+
+ 8: byte(h.csumOffset),
+ 9: byte(h.csumOffset >> 8),
+ }
+ return buf[:]
+}
+
// These constants are declared in linux/virtio_net.h.
const (
_VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
@@ -441,7 +475,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
@@ -463,7 +497,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
}
}
- vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf := vnetHdr.marshal()
builder.Add(vnetHdrBuf)
}
@@ -482,7 +516,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
}
var vnetHdrBuf []byte
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
@@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
- vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf = vnetHdr.marshal()
}
var builder iovec.Builder
@@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
}
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// SupportsHWGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.gsoKind
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
if e.hdrSize > 0 {
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index a7adf822b..4b7ef3aac 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -128,7 +128,7 @@ type readVDispatcher struct {
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
}
@@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
for i := range d.bufs {
d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
}
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 89df35822..3e816b0c7 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index bba6a6973..b1a28491d 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -25,6 +25,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
// endpoint represents a LinkEndpoint which implements a FIFO queue for all
// outgoing packets. endpoint can have 1 or more underlying queueDispatchers.
// All outgoing packets are consistenly hashed to a single underlying queue
@@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.lower.LinkAddress()
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
if gso, ok := e.lower.(stack.GSOEndpoint); ok {
return gso.GSOMaxSize()
@@ -149,6 +152,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
// WritePacket implements stack.LinkEndpoint.WritePacket.
//
// The packet must have the following fields populated:
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 6905b9ccb..a72eb1aad 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -47,7 +47,7 @@ go_test(
library = ":arp",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index e867b3c3f..0df39ae81 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 90075a70c..56b76a284 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
resPkt := r.holes[0].pkt
for i := 1; i < len(r.holes); i++ {
- fragData := r.holes[i].pkt.Data()
- resPkt.Data().ReadFromData(fragData, fragData.Size())
+ stack.MergeFragment(resPkt, r.holes[i].pkt)
}
return resPkt, r.proto, true, memConsumed, nil
}
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
index d21b4c7ef..fd944ce99 100644
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ b/pkg/tcpip/network/internal/ip/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ip",
srcs = [
"duplicate_address_detection.go",
+ "errors.go",
"generic_multicast_protocol.go",
"stats.go",
],
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index eed49f5d2..5123b7d6a 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
}
+ configs.Validate()
+
*d = DAD{
opts: opts,
configs: configs,
diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go
new file mode 100644
index 000000000..94f1cd1cb
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/errors.go
@@ -0,0 +1,85 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// ForwardingError represents an error that occured while trying to forward
+// a packet.
+type ForwardingError interface {
+ isForwardingError()
+ fmt.Stringer
+}
+
+// ErrTTLExceeded indicates that the received packet's TTL has been exceeded.
+type ErrTTLExceeded struct{}
+
+func (*ErrTTLExceeded) isForwardingError() {}
+
+func (*ErrTTLExceeded) String() string { return "ttl exceeded" }
+
+// ErrParameterProblem indicates the received packet had a problem with an IP
+// parameter.
+type ErrParameterProblem struct{}
+
+func (*ErrParameterProblem) isForwardingError() {}
+
+func (*ErrParameterProblem) String() string { return "parameter problem" }
+
+// ErrLinkLocalSourceAddress indicates the received packet had a link-local
+// source address.
+type ErrLinkLocalSourceAddress struct{}
+
+func (*ErrLinkLocalSourceAddress) isForwardingError() {}
+
+func (*ErrLinkLocalSourceAddress) String() string { return "link local destination address" }
+
+// ErrLinkLocalDestinationAddress indicates the received packet had a link-local
+// destination address.
+type ErrLinkLocalDestinationAddress struct{}
+
+func (*ErrLinkLocalDestinationAddress) isForwardingError() {}
+
+func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" }
+
+// ErrNoRoute indicates that a route for the received packet couldn't be found.
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isForwardingError() {}
+
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrMessageTooLong indicates the packet was too big for the outgoing MTU.
+//
+// +stateify savable
+type ErrMessageTooLong struct{}
+
+func (*ErrMessageTooLong) isForwardingError() {}
+
+func (*ErrMessageTooLong) String() string { return "message too long" }
+
+// ErrOther indicates the packet coould not be forwarded for a reason
+// captured by the contained error.
+type ErrOther struct {
+ Err tcpip.Error
+}
+
+func (*ErrOther) isForwardingError() {}
+
+func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) }
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index ac35d81e7..d22974b12 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ip holds IPv4/IPv6 common utilities.
package ip
import (
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index d06b26309..0c2b62127 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -16,80 +16,145 @@ package ip
import "gvisor.dev/gvisor/pkg/tcpip"
+// LINT.IfChange(MultiCounterIPForwardingStats)
+
+// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter
+// may have several versions.
+type MultiCounterIPForwardingStats struct {
+ // Unrouteable is the number of IP packets received which were dropped
+ // because the netstack could not construct a route to their
+ // destination.
+ Unrouteable tcpip.MultiCounterStat
+
+ // ExhaustedTTL is the number of IP packets received which were dropped
+ // because their TTL was exhausted.
+ ExhaustedTTL tcpip.MultiCounterStat
+
+ // LinkLocalSource is the number of IP packets which were dropped
+ // because they contained a link-local source address.
+ LinkLocalSource tcpip.MultiCounterStat
+
+ // LinkLocalDestination is the number of IP packets which were dropped
+ // because they contained a link-local destination address.
+ LinkLocalDestination tcpip.MultiCounterStat
+
+ // PacketTooBig is the number of IP packets which were dropped because they
+ // were too big for the outgoing MTU.
+ PacketTooBig tcpip.MultiCounterStat
+
+ // ExtensionHeaderProblem is the number of IP packets which were dropped
+ // because of a problem encountered when processing an IPv6 extension
+ // header.
+ ExtensionHeaderProblem tcpip.MultiCounterStat
+
+ // Errors is the number of IP packets received which could not be
+ // successfully forwarded.
+ Errors tcpip.MultiCounterStat
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
+ m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
+ m.Errors.Init(a.Errors, b.Errors)
+ m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
+ m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
+ m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
+ m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig)
+ m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
+}
+
+// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
+
// LINT.IfChange(MultiCounterIPStats)
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
- // PacketsReceived is the number of IP packets received from the link layer.
+ // PacketsReceived is the number of IP packets received from the link
+ // layer.
PacketsReceived tcpip.MultiCounterStat
- // DisabledPacketsReceived is the number of IP packets received from the link
- // layer when the IP layer is disabled.
+ // ValidPacketsReceived is the number of valid IP packets that reached the IP
+ // layer.
+ ValidPacketsReceived tcpip.MultiCounterStat
+
+ // DisabledPacketsReceived is the number of IP packets received from
+ // the link layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
- // InvalidDestinationAddressesReceived is the number of IP packets received
- // with an unknown or invalid destination address.
+ // InvalidDestinationAddressesReceived is the number of IP packets
+ // received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
- // InvalidSourceAddressesReceived is the number of IP packets received with a
- // source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received
+ // with a source address that should never have been received on the
+ // wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
- // PacketsDelivered is the number of incoming IP packets that are successfully
+ // PacketsDelivered is the number of incoming IP packets successfully
// delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
// PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
- // OutgoingPacketErrors is the number of IP packets which failed to write to a
- // link-layer endpoint.
+ // OutgoingPacketErrors is the number of IP packets which failed to
+ // write to a link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
- // MalformedPacketsReceived is the number of IP Packets that were dropped due
- // to the IP packet header failing validation checks.
+ // MalformedPacketsReceived is the number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
- // MalformedFragmentsReceived is the number of IP Fragments that were dropped
- // due to the fragment failing validation checks.
+ // MalformedFragmentsReceived is the number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
// IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
- // IPTablesInputDropped is the number of IP packets dropped in the Input
- // chain.
+ // IPTablesInputDropped is the number of IP packets dropped in the
+ // Input chain.
IPTablesInputDropped tcpip.MultiCounterStat
- // IPTablesOutputDropped is the number of IP packets dropped in the Output
- // chain.
+ // IPTablesForwardDropped is the number of IP packets dropped in the
+ // Forward chain.
+ IPTablesForwardDropped tcpip.MultiCounterStat
+
+ // IPTablesOutputDropped is the number of IP packets dropped in the
+ // Output chain.
IPTablesOutputDropped tcpip.MultiCounterStat
- // IPTablesPostroutingDropped is the number of IP packets dropped in the
- // Postrouting chain.
+ // IPTablesPostroutingDropped is the number of IP packets dropped in
+ // the Postrouting chain.
IPTablesPostroutingDropped tcpip.MultiCounterStat
- // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
- // of IPStats.
+ // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option
+ // stats out of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
OptionTimestampReceived tcpip.MultiCounterStat
- // OptionRecordRouteReceived is the number of Record Route options seen.
+ // OptionRecordRouteReceived is the number of Record Route options
+ // seen.
OptionRecordRouteReceived tcpip.MultiCounterStat
- // OptionRouterAlertReceived is the number of Router Alert options seen.
+ // OptionRouterAlertReceived is the number of Router Alert options
+ // seen.
OptionRouterAlertReceived tcpip.MultiCounterStat
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived tcpip.MultiCounterStat
+
+ // Forwarding collects stats related to IP forwarding.
+ Forwarding MultiCounterIPForwardingStats
}
// Init sets internal counters to track a and b counters.
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived)
m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
@@ -100,12 +165,14 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+ m.Forwarding.Init(&a.Forwarding, &b.Forwarding)
}
// LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index 1c4f583c7..cec3e62c4 100644
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -4,10 +4,7 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
- srcs = [
- "testutil.go",
- "testutil_unsafe.go",
- ],
+ srcs = ["testutil.go"],
visibility = [
"//pkg/tcpip/network/arp:__pkg__",
"//pkg/tcpip/network/internal/fragmentation:__pkg__",
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index e2cf24b67..605e9ef8d 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -19,8 +19,6 @@ package testutil
import (
"fmt"
"math/rand"
- "reflect"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi
}
return pkt
}
-
-func checkFieldCounts(ref, multi reflect.Value) error {
- refTypeName := ref.Type().Name()
- multiTypeName := multi.Type().Name()
- refNumField := ref.NumField()
- multiNumField := multi.NumField()
-
- if refNumField != multiNumField {
- return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
- }
-
- return nil
-}
-
-func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
- s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
- if !ok {
- return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
- }
-
- // The field names are expected to match (case insensitive).
- if !strings.EqualFold(refName, multiName) {
- return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
- }
-
- base := (*s).Value()
- m.Increment()
- if (*s).Value() != base+1 {
- return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
- }
-
- return nil
-}
-
-// ValidateMultiCounterStats verifies that every counter stored in multi is
-// correctly tracking its counterpart in the given counters.
-func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
- for _, c := range counters {
- if err := checkFieldCounts(c, multi); err != nil {
- return err
- }
- }
-
- for i := 0; i < multi.NumField(); i++ {
- multiName := multi.Type().Field(i).Name
- multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
-
- if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
- for _, c := range counters {
- if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
- return err
- }
- }
- } else {
- var countersNextField []reflect.Value
- for _, c := range counters {
- countersNextField = append(countersNextField, c.Field(i))
- }
- if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 74aad126c..bd63e0289 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
}
- if err := s.SetForwarding(test.netProto, true); err != nil {
- t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err)
}
if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
@@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
}
- if err := s.SetForwarding(test.netProto, false); err != nil {
- t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err)
}
if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7ee0495d9..c90974693 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -62,7 +62,7 @@ go_test(
library = ":ipv4",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index f663fdc0b..d1a82b584 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
return
}
- // Skip the ip header, then deliver the error.
- pkt.Data().TrimFront(hlen)
+ // Keep needed information before trimming header.
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
+ dstAddr := hdr.DestinationAddress()
+ // Skip the ip header, then deliver the error.
+ pkt.Data().DeleteFront(hlen)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
@@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
- pkt.Data().TrimFront(header.ICMPv4MinimumSize)
- switch h.Code() {
+ mtu := h.MTU()
+ code := h.Code()
+ pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
- networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize)
if err != nil {
networkMTU = 0
}
@@ -383,6 +387,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// icmpReason is a marker interface for IPv4 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding indicates whether or not the error arose while attempting to
+ // forward a packet.
isForwarding() bool
}
@@ -442,6 +448,39 @@ func (r *icmpReasonParamProblem) isForwarding() bool {
return r.forwarding
}
+// icmpReasonNetworkUnreachable is an error in which the network specified in
+// the internet destination field of the datagram is unreachable.
+type icmpReasonNetworkUnreachable struct{}
+
+func (*icmpReasonNetworkUnreachable) isICMPReason() {}
+func (*icmpReasonNetworkUnreachable) isForwarding() bool {
+ // If we hit a Net Unreachable error, then we know we are operating as
+ // a router. As per RFC 792 page 5, Destination Unreachable Message,
+ //
+ // If, according to the information in the gateway's routing tables,
+ // the network specified in the internet destination field of a
+ // datagram is unreachable, e.g., the distance to the network is
+ // infinity, the gateway may send a destination unreachable message to
+ // the internet source host of the datagram.
+ return true
+}
+
+// icmpReasonFragmentationNeeded is an error where a packet requires
+// fragmentation while also having the Don't Fragment flag set, as per RFC 792
+// page 3, Destination Unreachable Message.
+type icmpReasonFragmentationNeeded struct{}
+
+func (*icmpReasonFragmentationNeeded) isICMPReason() {}
+func (*icmpReasonFragmentationNeeded) isForwarding() bool {
+ // If we hit a Don't Fragment error, then we know we are operating as a router.
+ // As per RFC 792 page 4, Destination Unreachable Message,
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -610,6 +649,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetworkUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4NetUnreachable)
+ counter = sent.dstUnreachable
+ case *icmpReasonFragmentationNeeded:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a0bc06465..23178277a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -62,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -81,6 +88,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -150,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -433,6 +464,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
}
if packetMustBeFragmented(pkt, networkMTU) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment
+ // is set but the packet must be fragmented for the non-forwarding case.
+ return &tcpip.ErrMessageTooLong{}
+ }
sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
@@ -599,22 +636,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
- if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
- // As per RFC 3927 section 7,
- //
- // A router MUST NOT forward a packet with an IPv4 Link-Local source or
- // destination address, irrespective of the router's default route
- // configuration or routes obtained from dynamic routing protocols.
- //
- // A router which receives a packet with an IPv4 Link-Local source or
- // destination address MUST NOT forward the packet. This prevents
- // forwarding of packets back onto the network segment from which they
- // originated, or to any other segment.
- return nil
+ // As per RFC 3927 section 7,
+ //
+ // A router MUST NOT forward a packet with an IPv4 Link-Local source or
+ // destination address, irrespective of the router's default route
+ // configuration or routes obtained from dynamic routing protocols.
+ //
+ // A router which receives a packet with an IPv4 Link-Local source or
+ // destination address MUST NOT forward the packet. This prevents
+ // forwarding of packets back onto the network segment from which they
+ // originated, or to any other segment.
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
}
ttl := h.TTL()
@@ -624,7 +664,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// If the gateway processing a datagram finds the time to live field
// is zero it must discard the datagram. The gateway may also notify
// the source host via the time exceeded message.
- return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
if opts := h.Options(); len(opts) != 0 {
@@ -635,10 +680,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
pointer: optProblem.Pointer,
forwarding: true,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- e.stats.ip.MalformedPacketsReceived.Increment()
}
- return nil // option problems are not reported locally.
+ return &ip.ErrParameterProblem{}
}
copied := copy(opts, newOpts)
if copied != len(newOpts) {
@@ -655,18 +698,44 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -680,10 +749,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ IsForwardedPacket: true,
+ })); err.(type) {
+ case nil:
+ return nil
+ case *tcpip.ErrMessageTooLong:
+ // As per RFC 792, page 4, Destination Unreachable:
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ //
+ // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit
+ // while sending the packet and returns this error iff fragmentation is
+ // necessary and the bit is also set.
+ _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt)
+ return &ip.ErrMessageTooLong{}
+ default:
+ return &ip.ErrOther{Err: err}
+ }
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -764,6 +851,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
+ stats.ip.ValidPacketsReceived.Increment()
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -794,11 +882,30 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
- _ = e.forwardPacket(pkt)
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrParameterProblem:
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ case *ip.ErrMessageTooLong:
+ stats.ip.Forwarding.PacketTooBig.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ stats.ip.Forwarding.Errors.Increment()
return
}
@@ -955,8 +1062,8 @@ func (e *endpoint) Close() {
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err == nil {
@@ -967,8 +1074,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
}
@@ -981,8 +1088,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
loopback := e.nic.IsLoopback()
return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
@@ -1067,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1088,12 +1194,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1206,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 7d413c455..da9cc0ae8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -112,67 +112,103 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+type forwardedPacket struct {
+ fragments []fragmentInfo
+}
+
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
randomTimeOffset = 0x10203040
)
- ipv4Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
PrefixLen: 8,
}
- ipv4Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
PrefixLen: 8,
}
- remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
- remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+ outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2")
+ remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2")
+ unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2")
+ multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0")
+ linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0")
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
- options header.IPv4Options
- forwardedOptions header.IPv4Options
- icmpType header.ICMPv4Type
- icmpCode header.ICMPv4Code
+ name string
+ TTL uint8
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ expectErrorICMP bool
+ ipFlags uint8
+ mtu uint32
+ payloadLength int
+ options header.IPv4Options
+ forwardedOptions header.IPv4Options
+ icmpType header.ICMPv4Type
+ icmpCode header.ICMPv4Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
+ expectPacketForwarded bool
+ expectedFragmentsForwarded []fragmentInfo
}{
{
name: "TTL of zero",
TTL: 0,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
expectErrorICMP: true,
icmpType: header.ICMPv4TimeExceeded,
icmpCode: header.ICMPv4TTLExceeded,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "TTL of one",
- TTL: 1,
- expectErrorICMP: false,
+ name: "TTL of one",
+ TTL: 1,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "TTL of two",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "four EOL options",
- TTL: 2,
- expectErrorICMP: false,
- options: header.IPv4Options{0, 0, 0, 0},
- forwardedOptions: header.IPv4Options{0, 0, 0, 0},
+ name: "four EOL options",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
+ options: header.IPv4Options{0, 0, 0, 0},
+ forwardedOptions: header.IPv4Options{0, 0, 0, 0},
},
{
- name: "TS type 1 full",
- TTL: 2,
+ name: "TS type 1 full",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0xF1,
192, 168, 1, 12,
@@ -183,8 +219,11 @@ func TestForwarding(t *testing.T) {
icmpCode: header.ICMPv4UnusedCode,
},
{
- name: "TS type 0",
- TTL: 2,
+ name: "TS type 0",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 24, 21, 0x00,
1, 2, 3, 4,
@@ -201,10 +240,14 @@ func TestForwarding(t *testing.T) {
13, 14, 15, 16,
0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
},
+ expectPacketForwarded: true,
},
{
- name: "end of options list",
- TTL: 2,
+ name: "end of options list",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0x11,
192, 168, 1, 12,
@@ -220,11 +263,89 @@ func TestForwarding(t *testing.T) {
0, 0, 0, // 7 bytes unknown option removed.
0, 0, 0, 0,
},
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Network unreachable",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: unreachableIPv4Addr,
+ expectErrorICMP: true,
+ mtu: ipv4.MaxTotalSize,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4NetUnreachable,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Multicast destination",
+ TTL: 2,
+ destAddr: multicastIPv4Addr,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: linkLocalIPv4Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv4Addr,
+ destAddr: remoteIPv4Addr2,
+ expectLinkLocalSourceError: true,
+ },
+ {
+ name: "Fragmentation needed and DF set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ ipFlags: header.IPv4FlagDontFragment,
+ // We've picked this MTU because it is:
+ //
+ // 1) Greater than the minimum MTU that IPv4 hosts are required to process
+ // (576 bytes). As per RFC 1812, Section 4.3.2.3:
+ //
+ // The ICMP datagram SHOULD contain as much of the original datagram as
+ // possible without the length of the ICMP datagram exceeding 576 bytes.
+ //
+ // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a
+ // complete ICMP packet on the incoming endpoint (and make assertions about
+ // it).
+ //
+ // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose
+ // size exceeds the MTU.
+ mtu: 1000,
+ payloadLength: 1004,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4FragmentationNeeded,
+ },
+ {
+ name: "Fragmentation needed and DF not set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: 1000,
+ payloadLength: 1004,
+ expectPacketForwarded: true,
+ // Combined, these fragments have length of 1012 octets, which is equal to
+ // the length of the payload (1004 octets), plus the length of the ICMP
+ // header (8 octets).
+ expectedFragmentsForwarded: []fragmentInfo{
+ // The first fragment has a length of the greatest multiple of 8 which is
+ // less than or equal to to `mtu - header.IPv4MinimumSize`.
+ {offset: 0, payloadSize: uint16(976), more: true},
+ // The next fragment holds the rest of the packet.
+ {offset: uint16(976), payloadSize: 36, more: false},
+ },
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
@@ -236,46 +357,52 @@ func TestForwarding(t *testing.T) {
clock.Advance(time.Millisecond * randomTimeOffset)
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
- e2 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ expectedEmittedPacketCount := 1
+ if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount {
+ expectedEmittedPacketCount = len(test.expectedFragmentsForwarded)
+ }
+ outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr)
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv4Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv4Addr.Subnet(),
+ NIC: incomingNICID,
},
{
- Destination: ipv4Addr2.Subnet(),
- NIC: nicID2,
+ Destination: outgoingIPv4Addr.Subnet(),
+ NIC: outgoingNICID,
},
})
- if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err)
}
ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpHeaderLength := header.ICMPv4MinimumSize
+ totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
+ hdr := buffer.NewPrependable(totalLength)
+ hdr.Prepend(test.payloadLength)
+ icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
icmp.SetType(header.ICMPv4Echo)
@@ -284,11 +411,12 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(^header.Checksum(icmp, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
+ TotalLength: uint16(totalLength),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: test.TTL,
- SrcAddr: remoteIPv4Addr1,
- DstAddr: remoteIPv4Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
+ Flags: test.ipFlags,
})
if len(test.options) != 0 {
ip.SetHeaderLength(uint8(ipHeaderLength))
@@ -303,51 +431,122 @@ func TestForwarding(t *testing.T) {
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
- e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
+ // We expect the ICMP packet to contain as much of the original packet as
+ // possible up to a limit of 576 bytes, split between payload, IP header,
+ // and ICMP header.
+ expectedICMPPayloadLength := func() int {
+ maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize
+ maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength
+ if len(hdr.View()) > maxICMPPayloadLength {
+ return maxICMPPayloadLength
+ }
+ return len(hdr.View())
+ }
+
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv4Addr1.Address),
- checker.DstAddr(remoteIPv4Addr1),
+ checker.SrcAddr(incomingIPv4Addr.Address),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(ipv4.DefaultTTL),
checker.ICMPv4(
checker.ICMPv4Checksum(),
checker.ICMPv4Type(test.icmpType),
checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
),
)
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
- if n := e2.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ if test.expectPacketForwarded {
+ if len(test.expectedFragmentsForwarded) != 0 {
+ fragmentedPackets := []*stack.PacketBuffer{}
+ for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
+ reply, ok = outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo fragment through outgoing NIC")
+ }
+ fragmentedPackets = append(fragmentedPackets, reply.Pkt)
+ }
+
+ // The forwarded packet's TTL will have been decremented.
+ ipHeader := header.IPv4(requestPkt.NetworkHeader().View())
+ ipHeader.SetTTL(ipHeader.TTL() - 1)
+
+ // Forwarded packets have available header bytes equalling the sum of the
+ // maximum IP header size and the maximum size allocated for link layer
+ // headers. In this case, no size is allocated for link layer headers.
+ expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
+ if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
+ t.Error(err)
+ }
+ } else {
+ reply, ok = outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
+ checker.TTL(test.TTL-1),
+ checker.IPv4Options(test.forwardedOptions),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
}
} else {
- reply, ok := e2.Read()
- if !ok {
- t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ if reply, ok = outgoingEndpoint.Read(); ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+ }
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
}
+ return 0
+ }
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv4Addr1),
- checker.DstAddr(remoteIPv4Addr2),
- checker.TTL(test.TTL-1),
- checker.IPv4Options(test.forwardedOptions),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Payload(nil),
- ),
- )
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
- if n := e1.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
- }
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want {
+ t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
}
})
}
@@ -1170,13 +1369,25 @@ func TestIPv4Sanity(t *testing.T) {
}
}
-// comparePayloads compared the contents of all the packets against the contents
-// of the source packet.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+// compareFragments compares the contents of a set of fragmented packets against
+// the contents of a source packet.
+//
+// If withIPHeader is set to true, we will validate the fragmented packets' IP
+// headers against the source packet's IP header. If set to false, we validate
+// the fragmented packets' IP headers against each other.
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error {
// Make a complete array of the sourcePacket packet.
- source := header.IPv4(packets[0].NetworkHeader().View())
+ var source header.IPv4
vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
+
+ // If the packet to be fragmented contains an IPv4 header, use that header for
+ // validating fragment headers. Else, use the header of the first fragment.
+ if withIPHeader {
+ source = header.IPv4(vv.ToView())
+ } else {
+ source = header.IPv4(packets[0].NetworkHeader().View())
+ source = append(source, vv.ToView()...)
+ }
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -1199,12 +1410,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
if got := fragmentIPHeader.TransportProtocol(); got != proto {
return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
}
- if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
}
+ if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes)
+ }
if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
}
@@ -1220,6 +1431,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+
+ // If we are validating against the original IP header, we should exclude the
+ // ID field, which will only be set fo fragmented packets.
+ if withIPHeader {
+ fragmentIPHeader.SetID(0)
+ fragmentIPHeader.SetChecksum(0)
+ fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum())
+ }
if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
@@ -1348,7 +1567,7 @@ func TestFragmentationWritePacket(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})
@@ -1429,7 +1648,7 @@ func TestFragmentationWritePackets(t *testing.T) {
}
fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
index a637f9d50..d1f9e3cf5 100644
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index db998e83e..f99cbf8f3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -45,6 +45,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 1319db32b..307e1972d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
return
}
+ // Keep needed information before trimming header.
+ p := hdr.TransportProtocol()
+ dstAddr := hdr.DestinationAddress()
+
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().TrimFront(header.IPv6MinimumSize)
- p := hdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6MinimumSize)
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// because they don't have the transport headers.
return
}
+ p = fragHdr.TransportProtocol()
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().TrimFront(header.IPv6FragmentHeaderSize)
- p = fragHdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
}
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize)
networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
+ pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
@@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch header.ICMPv6(hdr).Code() {
+ code := header.ICMPv6(hdr).Code()
+ pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
@@ -741,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
@@ -951,6 +951,19 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
// icmpReason is a marker interface for IPv6 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding indicates whether or not the error arose while attempting to
+ // forward a packet.
+ isForwarding() bool
+ // respondToMulticast indicates whether this error falls under the exception
+ // outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are two
+ // exceptions to this rule: (1) the Packet Too Big Message (Section 3.2) to
+ // allow Path MTU discovery to work for IPv6 multicast, and (2) the Parameter
+ // Problem Message, Code 2 (Section 3.4) reporting an unrecognized IPv6
+ // option (see Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondsToMulticast() bool
}
// icmpReasonParameterProblem is an error during processing of extension headers
@@ -958,18 +971,6 @@ type icmpReason interface {
type icmpReasonParameterProblem struct {
code header.ICMPv6Code
- // respondToMulticast indicates that we are sending a packet that falls under
- // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
- //
- // (e.3) A packet destined to an IPv6 multicast address. (There are
- // two exceptions to this rule: (1) the Packet Too Big Message
- // (Section 3.2) to allow Path MTU discovery to work for IPv6
- // multicast, and (2) the Parameter Problem Message, Code 2
- // (Section 3.4) reporting an unrecognized IPv6 option (see
- // Section 4.2 of [IPv6]) that has the Option Type highest-
- // order two bits set to 10).
- respondToMulticast bool
-
// pointer is defined in the RFC 4443 setion 3.4 which reads:
//
// Pointer Identifies the octet offset within the invoking packet
@@ -979,9 +980,20 @@ type icmpReasonParameterProblem struct {
// packet if the field in error is beyond what can fit
// in the maximum size of an ICMPv6 error message.
pointer uint32
+
+ forwarding bool
+
+ respondToMulticast bool
}
func (*icmpReasonParameterProblem) isICMPReason() {}
+func (p *icmpReasonParameterProblem) isForwarding() bool {
+ return p.forwarding
+}
+
+func (p *icmpReasonParameterProblem) respondsToMulticast() bool {
+ return p.respondToMulticast
+}
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
@@ -989,12 +1001,76 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+func (*icmpReasonPortUnreachable) isForwarding() bool {
+ return false
+}
+
+func (*icmpReasonPortUnreachable) respondsToMulticast() bool {
+ return false
+}
+
+// icmpReasonNetUnreachable is an error where no route can be found to the
+// network of the final destination.
+type icmpReasonNetUnreachable struct{}
+
+func (*icmpReasonNetUnreachable) isICMPReason() {}
+
+func (*icmpReasonNetUnreachable) isForwarding() bool {
+ // If we hit a Network Unreachable error, then we also know we are
+ // operating as a router. As per RFC 4443 section 3.1:
+ //
+ // If the reason for the failure to deliver is lack of a matching
+ // entry in the forwarding node's routing table, the Code field is
+ // set to 0 (Network Unreachable).
+ return true
+}
+
+func (*icmpReasonNetUnreachable) respondsToMulticast() bool {
+ return false
+}
+
+// icmpReasonFragmentationNeeded is an error where a packet is to big to be sent
+// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message.
+type icmpReasonPacketTooBig struct{}
+
+func (*icmpReasonPacketTooBig) isICMPReason() {}
+
+func (*icmpReasonPacketTooBig) isForwarding() bool {
+ // If we hit a Packet Too Big error, then we know we are operating as a router.
+ // As per RFC 4443 section 3.2:
+ //
+ // A Packet Too Big MUST be sent by a router in response to a packet that it
+ // cannot forward because the packet is larger than the MTU of the outgoing
+ // link.
+ return true
+}
+
+func (*icmpReasonPacketTooBig) respondsToMulticast() bool {
+ return true
+}
+
// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
// transit to its final destination, as per RFC 4443 section 3.3.
type icmpReasonHopLimitExceeded struct{}
func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+func (*icmpReasonHopLimitExceeded) isForwarding() bool {
+ // If we hit a Hop Limit Exceeded error, then we know we are operating
+ // as a router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard
+ // the packet and originate an ICMPv6 Time Exceeded message with Code
+ // 0 to the source of the packet. This indicates either a routing
+ // loop or too small an initial Hop Limit value.
+ return true
+}
+
+func (*icmpReasonHopLimitExceeded) respondsToMulticast() bool {
+ return false
+}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -1002,6 +1078,14 @@ type icmpReasonReassemblyTimeout struct{}
func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+func (*icmpReasonReassemblyTimeout) isForwarding() bool {
+ return false
+}
+
+func (*icmpReasonReassemblyTimeout) respondsToMulticast() bool {
+ return false
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
@@ -1030,25 +1114,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- var allowResponseToMulticast bool
- if reason, ok := reason.(*icmpReasonParameterProblem); ok {
- allowResponseToMulticast = reason.respondToMulticast
- }
-
+ allowResponseToMulticast := reason.respondsToMulticast()
isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
- // If we hit a Hop Limit Exceeded error, then we know we are operating as a
- // router. As per RFC 4443 section 3.3:
- //
- // If a router receives a packet with a Hop Limit of zero, or if a
- // router decrements a packet's Hop Limit to zero, it MUST discard the
- // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
- // the source of the packet. This indicates either a routing loop or
- // too small an initial Hop Limit value.
- //
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
@@ -1058,7 +1129,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
// packet as "multicast addresses must not be used as source addresses in IPv6
// packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
+ if reason.isForwarding() || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -1147,6 +1218,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
+ counter = sent.dstUnreachable
+ case *icmpReasonPacketTooBig:
+ icmpHdr.SetType(header.ICMPv6PacketTooBig)
+ icmpHdr.SetCode(header.ICMPv6UnusedCode)
+ counter = sent.packetTooBig
case *icmpReasonHopLimitExceeded:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index e457be3cf..040cd4bc8 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
+ }
}
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f7510c243..95e11ac51 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -187,6 +193,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -405,27 +417,39 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
- // When transitioning into an IPv6 router, host-only state (NDP discovered
- // routers, discovered on-link prefixes, and auto-generated addresses) is
- // cleaned up/invalidated and NDP router solicitations are stopped.
- e.mu.ndp.stopSolicitingRouters()
- e.mu.ndp.cleanupState(true /* hostOnly */)
-
// As per RFC 4291 section 2.8:
//
// A router is required to recognize all addresses that a host is
@@ -449,28 +473,19 @@ func (e *endpoint) transitionForwarding(forwarding bool) {
panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err))
}
}
-
- return
- }
-
- for _, g := range allRoutersGroups {
- switch err := e.leaveGroupLocked(g).(type) {
- case nil:
- case *tcpip.ErrBadLocalAddress:
- // The endpoint may have already left the multicast group.
- default:
- panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ } else {
+ for _, g := range allRoutersGroups {
+ switch err := e.leaveGroupLocked(g).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ }
}
}
- // When transitioning into an IPv6 host, NDP router solicitations are
- // started if the endpoint is enabled.
- //
- // If the endpoint is not currently enabled, routers will be solicited when
- // the endpoint becomes enabled (if it is still a host).
- if e.Enabled() {
- e.mu.ndp.startSolicitingRouters()
- }
+ e.mu.ndp.forwardingChanged(forwarding)
}
// Enable implements stack.NetworkEndpoint.
@@ -552,17 +567,7 @@ func (e *endpoint) Enable() tcpip.Error {
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
}
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyway.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !e.protocol.Forwarding() {
- e.mu.ndp.startSolicitingRouters()
- }
-
+ e.mu.ndp.startSolicitingRouters()
return nil
}
@@ -613,7 +618,7 @@ func (e *endpoint) disableLocked() {
return true
})
- e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.mu.ndp.cleanupState()
// The endpoint may have already left the multicast group.
switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) {
@@ -786,6 +791,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
}
if packetMustBeFragmented(pkt, networkMTU) {
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // As per RFC 2460, section 4.5:
+ // Unlike IPv4, fragmentation in IPv6 is performed only by source nodes,
+ // not by routers along a packet's delivery path.
+ return &tcpip.ErrMessageTooLong{}
+ }
sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
@@ -928,16 +939,19 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv6(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
- if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
- // As per RFC 4291 section 2.5.6,
- //
- // Routers must not forward any packets with Link-Local source or
- // destination addresses to other links.
- return nil
+ // As per RFC 4291 section 2.5.6,
+ //
+ // Routers must not forward any packets with Link-Local source or
+ // destination addresses to other links.
+ if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
}
hopLimit := h.HopLimit()
@@ -949,21 +963,56 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// packet and originate an ICMPv6 Time Exceeded message with Code 0 to
// the source of the packet. This indicates either a routing loop or
// too small an initial Hop Limit value.
- return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ // Check extension headers for any errors requiring action during forwarding.
+ if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil {
+ return &ip.ErrParameterProblem{}
+ }
+
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning the
+ // ICMP packet because the original error is more relevant to the caller.
+ _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -975,10 +1024,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ IsForwardedPacket: true,
+ })); err.(type) {
+ case nil:
+ return nil
+ case *tcpip.ErrMessageTooLong:
+ // As per RFC 4443, section 3.2:
+ // A Packet Too Big MUST be sent by a router in response to a packet that
+ // it cannot forward because the packet is larger than the MTU of the
+ // outgoing link.
+ _ = e.protocol.returnError(&icmpReasonPacketTooBig{}, pkt)
+ return &ip.ErrMessageTooLong{}
+ default:
+ return &ip.ErrOther{Err: err}
+ }
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
@@ -1059,6 +1121,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
+ stats.ValidPacketsReceived.Increment()
+
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1075,15 +1139,54 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ e.stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ e.stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ e.stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ e.stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrParameterProblem:
+ e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment()
+ case *ip.ErrMessageTooLong:
+ e.stats.ip.Forwarding.PacketTooBig.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ e.stats.ip.Forwarding.Errors.Increment()
+ return
+ }
- _ = e.forwardPacket(pkt)
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IPTablesInputDropped.Increment()
return
}
+ // Any returned error is only useful for terminating execution early, but
+ // we have nothing left to do, so we can drop it.
+ _ = e.processExtensionHeaders(h, pkt, false /* forwarding */)
+}
+
+// processExtensionHeaders processes the extension headers in the given packet.
+// Returns an error if the processing of a header failed or if the packet should
+// be discarded.
+func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error {
+ stats := e.stats.ip
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// Create a VV to parse the packet. We don't plan to modify anything here.
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1094,15 +1197,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
vv.AppendViews(pkt.Data().Views())
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
- // iptables filtering. All packets that reach here are intended for
- // this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
- // iptables is telling us to drop the packet.
- stats.IPTablesInputDropped.Increment()
- return
- }
-
var (
hasFragmentHeader bool
routerAlert *header.IPv6RouterAlertOption
@@ -1115,22 +1209,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
extHdr, done, err := it.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
}
+ // As per RFC 8200, section 4:
+ //
+ // Extension headers (except for the Hop-by-Hop Options header) are
+ // not processed, inserted, or deleted by any node along a packet's
+ // delivery path until the packet reaches the node identified in the
+ // Destination Address field of the IPv6 header.
+ //
+ // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension
+ // header is restricted to appear first in the list of extension headers.
+ //
+ // Therefore, we can immediately return once we hit any header other
+ // than the Hop-by-Hop header while forwarding a packet.
+ if forwarding {
+ if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok {
+ return nil
+ }
+ }
+
switch extHdr := extHdr.(type) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
_ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: previousHeaderStart,
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart)
}
optsIt := extHdr.Iter()
@@ -1139,7 +1252,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1154,7 +1267,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// There MUST only be one option of this type, regardless of
// value, per Hop-by-Hop header.
stats.MalformedPacketsReceived.Increment()
- return
+ return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert)
}
routerAlert = opt
stats.OptionRouterAlertReceived.Increment()
@@ -1162,10 +1275,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1180,10 +1293,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt))
}
}
}
@@ -1205,8 +1319,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
+ // For the sake of consistency, we're using the value of `forwarding`
+ // here, even though it should always be false if we've reached this
+ // point. If `forwarding` is true here, we're executing undefined
+ // behavior no matter what.
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr)
}
case header.IPv6FragmentExtHdr:
@@ -1241,7 +1360,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1269,7 +1388,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
default:
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr)
}
}
@@ -1278,7 +1397,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Drop the packet as it's marked as a fragment but has no payload.
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("fragment has no payload")
}
// As per RFC 2460 Section 4.5:
@@ -1296,7 +1415,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
}, pkt)
- return
+ return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen)
}
// The packet is a fragment, let's try to reassemble it.
@@ -1310,14 +1429,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Parameter Problem, Code 0, message should be sent to the source of
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
- if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
+ lengthAfterReassembly := int(start) + fragmentPayloadLen
+ if lengthAfterReassembly > header.IPv6MaximumPayloadSize {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
}, pkt)
- return
+ return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize)
}
// Note that pkt doesn't have its transport header set after reassembly,
@@ -1339,7 +1459,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if ready {
@@ -1361,7 +1481,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1372,10 +1492,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown destination header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1392,9 +1512,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
}, pkt)
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt))
}
}
@@ -1402,13 +1522,19 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
+ // Calculate the number of octets parsed from data. We want to remove all
+ // the data except the unparsed portion located at the end, which its size
+ // is extHdr.Buf.Size().
+ trim := pkt.Data().Size() - extHdr.Buf.Size()
+
// For unfragmented packets, extHdr still contains the transport header.
// Get rid of it.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
- extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
- pkt.Data().Replace(extHdr.Buf)
+ trim += pkt.TransportHeader().View().Size()
+
+ pkt.Data().DeleteFront(trim)
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
@@ -1425,6 +1551,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
_ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
+ return fmt.Errorf("destination port unreachable")
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -1456,6 +1583,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownHeader,
pointer: prevHdrIDOffset,
}, pkt)
+ return fmt.Errorf("transport protocol unreachable")
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -1469,6 +1597,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
}
}
+ return nil
}
// Close cleans up resources associated with the endpoint.
@@ -1490,8 +1619,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1532,8 +1661,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
@@ -1610,8 +1739,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
}
@@ -1833,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1858,12 +1986,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -2038,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 40a793d6b..afc6c3547 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -31,8 +31,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
+ ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt.Clone()
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+ tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
@@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -3003,52 +3004,289 @@ func TestFragmentationErrors(t *testing.T) {
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
)
- ipv6Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10::1").To16()),
PrefixLen: 64,
}
- ipv6Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11::1").To16()),
PrefixLen: 64,
}
+ multicastIPv6Addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("ff00::").To16()),
+ PrefixLen: 64,
+ }
+
remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+ unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
+ linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16())
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker)
+ TTL uint8
+ expectErrorICMP bool
+ expectPacketForwarded bool
+ payloadLength int
+ countUnrouteablePackets uint64
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ icmpType header.ICMPv6Type
+ icmpCode header.ICMPv6Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
+ expectExtensionHeaderError bool
}{
{
name: "TTL of zero",
TTL: 0,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
},
{
name: "TTL of one",
TTL: 1,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "TTL of two",
+ TTL: 2,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "Network unreachable",
+ TTL: 2,
+ expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: unreachableIPv6Addr,
+ icmpType: header.ICMPv6DstUnreachable,
+ icmpCode: header.ICMPv6NetworkUnreachable,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Multicast destination",
+ TTL: 2,
+ countUnrouteablePackets: 1,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: linkLocalIPv6Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv6Addr,
+ destAddr: remoteIPv6Addr2,
+ expectLinkLocalSourceError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option skippable action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption()))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
},
{
- name: "TTL of three",
- TTL: 3,
- expectErrorICMP: false,
+ name: "Hopbyhop with router alert option",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with two router alert options",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Can't fragment",
+ TTL: 2,
+ payloadLength: header.IPv6MinimumMTU + 1,
+ expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6PacketTooBig,
+ icmpCode: header.ICMPv6UnusedCode,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Can't fragment multicast",
+ TTL: 2,
+ payloadLength: header.IPv6MinimumMTU + 1,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6PacketTooBig,
+ icmpCode: header.ICMPv6UnusedCode,
},
}
@@ -3059,41 +3297,60 @@ func TestForwarding(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv6Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv6Addr.Subnet(),
+ NIC: incomingNICID,
+ },
+ {
+ Destination: outgoingIPv6Addr.Subnet(),
+ NIC: outgoingNICID,
},
{
- Destination: ipv6Addr2.Subnet(),
- NIC: nicID2,
+ Destination: multicastIPv6Addr.Subnet(),
+ NIC: outgoingNICID,
},
})
- if err := s.SetForwarding(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
}
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
- icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ transportProtocol := header.ICMPv6ProtocolNumber
+ extHdrBytes := []byte{}
+ extHdrChecker := checker.IPv6ExtHdr()
+ if test.extHdr != nil {
+ nextHdrID := hopByHopExtHdrID
+ extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber))
+ transportProtocol = tcpip.TransportProtocolNumber(nextHdrID)
+ }
+ extHdrLen := len(extHdrBytes)
+
+ ipHeaderLength := header.IPv6MinimumSize
+ icmpHeaderLength := header.ICMPv6MinimumSize
+ totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ hdr := buffer.NewPrependable(totalLength)
+ hdr.Prepend(test.payloadLength)
+ icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
+
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
icmp.SetType(header.ICMPv6EchoRequest)
@@ -3101,52 +3358,72 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmp,
- Src: remoteIPv6Addr1,
- Dst: remoteIPv6Addr2,
+ Src: test.sourceAddr,
+ Dst: test.destAddr,
}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+ ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: header.ICMPv6ProtocolNumber,
+ PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ TransportProtocol: transportProtocol,
HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
- e1.InjectInbound(ProtocolNumber, requestPkt)
+ incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
- t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
+ }
+
+ // As per RFC 4443, page 9:
+ //
+ // The returned ICMP packet will contain as much of invoking packet
+ // as possible without the ICMPv6 packet exceeding the minimum IPv6
+ // MTU.
+ expectedICMPPayloadLength := func() int {
+ maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength
+ if len(hdr.View()) > maxICMPPayloadLength {
+ return maxICMPPayloadLength
+ }
+ return len(hdr.View())
}
checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv6Addr1.Address),
- checker.DstAddr(remoteIPv6Addr1),
+ checker.SrcAddr(incomingIPv6Addr.Address),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(DefaultTTL),
checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6TimeExceeded),
- checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
- checker.ICMPv6Payload([]byte(hdr.View())),
+ checker.ICMPv6Type(test.icmpType),
+ checker.ICMPv6Code(test.icmpCode),
+ checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
),
)
- if n := e2.Drain(); n != 0 {
+ if n := outgoingEndpoint.Drain(); n != 0 {
t.Fatalf("got e2.Drain() = %d, want = 0", n)
}
- } else {
- reply, ok := e2.Read()
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
+
+ reply, ok = outgoingEndpoint.Read()
+ if test.expectPacketForwarded {
if !ok {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv6Addr1),
- checker.DstAddr(remoteIPv6Addr2),
+ checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
+ extHdrChecker,
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6EchoRequest),
checker.ICMPv6Code(header.ICMPv6UnusedCode),
@@ -3154,9 +3431,46 @@ func TestForwarding(t *testing.T) {
),
)
- if n := e1.Drain(); n != 0 {
+ if n := incomingEndpoint.Drain(); n != 0 {
t.Fatalf("got e1.Drain() = %d, want = 0", n)
}
+ } else if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
+ }
+ return 0
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want {
+ t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d6e0a81a6..f0ff111c5 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -48,7 +48,7 @@ const (
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- defaultHandleRAs = true
+ defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
@@ -301,10 +301,60 @@ type NDPDispatcher interface {
OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
+var _ fmt.Stringer = HandleRAsConfiguration(0)
+
+// HandleRAsConfiguration enumerates when RAs may be handled.
+type HandleRAsConfiguration int
+
+const (
+ // HandlingRAsDisabled indicates that Router Advertisements will not be
+ // handled.
+ HandlingRAsDisabled HandleRAsConfiguration = iota
+
+ // HandlingRAsEnabledWhenForwardingDisabled indicates that router
+ // advertisements will only be handled when forwarding is disabled.
+ HandlingRAsEnabledWhenForwardingDisabled
+
+ // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always
+ // be handled, even when forwarding is enabled.
+ HandlingRAsAlwaysEnabled
+)
+
+// String implements fmt.Stringer.
+func (c HandleRAsConfiguration) String() string {
+ switch c {
+ case HandlingRAsDisabled:
+ return "HandlingRAsDisabled"
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return "HandlingRAsEnabledWhenForwardingDisabled"
+ case HandlingRAsAlwaysEnabled:
+ return "HandlingRAsAlwaysEnabled"
+ default:
+ return fmt.Sprintf("HandleRAsConfiguration(%d)", c)
+ }
+}
+
+// enabled returns true iff Router Advertisements may be handled given the
+// specified forwarding status.
+func (c HandleRAsConfiguration) enabled(forwarding bool) bool {
+ switch c {
+ case HandlingRAsDisabled:
+ return false
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return !forwarding
+ case HandlingRAsAlwaysEnabled:
+ return true
+ default:
+ panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c))
+ }
+}
+
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
// The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
+ //
+ // Ignored unless configured to handle Router Advertisements.
MaxRtrSolicitations uint8
// The amount of time between transmitting Router Solicitation messages.
@@ -318,8 +368,9 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements are processed.
- HandleRAs bool
+ // HandleRAs is the configuration for when Router Advertisements should be
+ // handled.
+ HandleRAs HandleRAsConfiguration
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
@@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
+ ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t
delete(tempAddrs, tempAddr)
}
-// removeSLAACAddresses removes all SLAAC addresses.
-//
-// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
-//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
- linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- var linkLocalPrefixes int
- for prefix, state := range ndp.slaacPrefixes {
- // RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them if we are cleaning up
- // host-only state.
- if keepLinkLocal && prefix == linkLocalSubnet {
- linkLocalPrefixes++
- continue
- }
-
- ndp.invalidateSLAACPrefix(prefix, state)
- }
-
- if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
- panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
- }
-}
-
// cleanupState cleans up ndp's state.
//
-// If hostOnly is true, then only host-specific state is cleaned up.
-//
// This function invalidates all discovered on-link prefixes, discovered
// routers, and auto-generated addresses.
//
-// If hostOnly is true, then the link-local auto-generated address aren't
-// invalidated as routers are also expected to generate a link-local address.
-//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
- ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
+func (ndp *ndpState) cleanupState() {
+ for prefix, state := range ndp.slaacPrefixes {
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
+// If ndp is not configured to handle Router Advertisements, routers will not
+// be solicited as there is no point soliciting routers if we don't handle their
+// advertisements.
+//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicitTimer.timer != nil {
@@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
+ return
+ }
+
// Calculate the random delay before sending our first RS, as per RFC
// 4861 section 6.3.7.
var delay time.Duration
@@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() {
}
}
+// forwardingChanged handles a change in forwarding configuration.
+//
+// If transitioning to a host, router solicitation will be started. Otherwise,
+// router solicitation will be stopped if NDP is not configured to handle RAs
+// as a router.
+//
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) forwardingChanged(forwarding bool) {
+ if forwarding {
+ if ndp.configs.HandleRAs.enabled(forwarding) {
+ return
+ }
+
+ ndp.stopSolicitingRouters()
+ return
+ }
+
+ // Solicit routers when transitioning to a host.
+ //
+ // If the endpoint is not currently enabled, routers will be solicited when
+ // the endpoint becomes enabled (if it is still a host).
+ if ndp.ep.Enabled() {
+ ndp.startSolicitingRouters()
+ }
+}
+
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 52b9a200c..234e34952 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
+ const nicID = 1
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
var extHdrs header.IPv6ExtHdrSerializer
@@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) {
},
}
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+
for _, typ := range types {
for _, isRouter := range []bool{false, true} {
name := typ.name
@@ -875,13 +872,35 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
+ }
}
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatal("cannot find network endpoint instance for IPv6")
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }})
+
stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
@@ -906,12 +925,12 @@ func TestNDPValidation(t *testing.T) {
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ t.Errorf("got invalid.Value() = %d, want = 0", got)
}
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ // Should initially not have dropped any packets.
if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ t.Errorf("got routerOnly.Value() = %d, want = 0", got)
}
if t.Failed() {
@@ -931,18 +950,18 @@ func TestNDPValidation(t *testing.T) {
want = 1
}
if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ t.Errorf("got invalid.Value() = %d, want = %d", got, want)
}
want = 0
if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
+ // Router only packets are expected to be dropped when operating
+ // as a host.
want = 1
}
if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
}
-
})
}
})
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
index c2758352f..2f18f60e8 100644
--- a/pkg/tcpip/network/ipv6/stats.go
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -29,6 +29,10 @@ type Stats struct {
// ICMP holds ICMPv6 statistics.
ICMP tcpip.ICMPv6Stats
+
+ // UnhandledRouterAdvertisements is the number of Router Advertisements that
+ // were observed but not handled.
+ UnhandledRouterAdvertisements *tcpip.StatCounter
}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index a6c877158..b26936b7f 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -18,6 +18,7 @@ import (
"math"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -213,7 +214,7 @@ type SocketOptions struct {
getSendBufferLimits GetSendBufferLimits `state:"manual"`
// sendBufferSize determines the send buffer size for this socket.
- sendBufferSize int64
+ sendBufferSize atomicbitops.AlignedAtomicInt64
// getReceiveBufferLimits provides the handler to get the min, default and
// max size for receive buffer. It is initialized at the creation time and
@@ -612,7 +613,7 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
// GetSendBufferSize gets value for SO_SNDBUF option.
func (so *SocketOptions) GetSendBufferSize() int64 {
- return atomic.LoadInt64(&so.sendBufferSize)
+ return so.sendBufferSize.Load()
}
// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
@@ -621,7 +622,7 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
v := sendBufferSize
if !notify {
- atomic.StoreInt64(&so.sendBufferSize, v)
+ so.sendBufferSize.Store(v)
return
}
@@ -647,7 +648,7 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
// Notify endpoint about change in buffer size.
newSz := so.handler.OnSetSendBufferSize(v)
- atomic.StoreInt64(&so.sendBufferSize, newSz)
+ so.sendBufferSize.Store(newSz)
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 2bd6a67f5..84aa6a9e4 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -73,6 +73,8 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/atomicbitops",
+ "//pkg/buffer",
"//pkg/ilist",
"//pkg/log",
"//pkg/rand",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index e5590ecc0..ce9cebdaa 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
// Regardless how the address was obtained, it will be acquired before it is
// returned.
func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
- a.mu.Lock()
- defer a.mu.Unlock()
+ lookup := func() *addressState {
+ if addrState, ok := a.mu.endpoints[localAddr]; ok {
+ if !addrState.IsAssigned(allowTemp) {
+ return nil
+ }
- if addrState, ok := a.mu.endpoints[localAddr]; ok {
- if !addrState.IsAssigned(allowTemp) {
- return nil
- }
+ if !addrState.IncRef() {
+ panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ }
- if !addrState.IncRef() {
- panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ return addrState
}
- return addrState
- }
-
- if f != nil {
- for _, addrState := range a.mu.endpoints {
- if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
- return addrState
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
}
}
+ return nil
+ }
+ // Avoid exclusive lock on mu unless we need to add a new address.
+ a.mu.RLock()
+ ep := lookup()
+ a.mu.RUnlock()
+
+ if ep != nil {
+ return ep
}
if !allowTemp {
return nil
}
+ // Acquire state lock in exclusive mode as we need to add a new temporary
+ // endpoint.
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ // Do the lookup again in case another goroutine added the address in the time
+ // we released and acquired the lock.
+ ep = lookup()
+ if ep != nil {
+ return ep
+ }
+
+ // Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
if err != nil {
@@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// expect no error.
panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
}
+
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 2d74e0abc..7107d598d 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct {
nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
@@ -101,7 +106,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: vv.ToView().ToVectorisedView(),
})
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
_ = r.WriteHeaderIncludedPacket(pkt)
}
@@ -169,11 +174,6 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -242,16 +242,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
@@ -264,6 +264,8 @@ type fwdTestPacketInfo struct {
Pkt *PacketBuffer
}
+var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)
+
type fwdTestLinkEndpoint struct {
dispatcher NetworkDispatcher
mtu uint32
@@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
return caps | CapabilityResolutionRequired
}
-// GSOMaxSize returns the maximum GSO packet size.
-func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
- return 1 << 15
-}
-
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
@@ -370,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}},
})
- // Enable forwarding.
- s.SetForwarding(proto.Number(), true)
+ protoNum := proto.Number()
+ if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err)
+ }
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index e2894c548..3670d5995 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -177,6 +177,7 @@ func DefaultTables() *IPTables {
priorities: [NumHooks][]TableID{
Prerouting: {MangleID, NATID},
Input: {NATID, FilterID},
+ Forward: {FilterID},
Output: {MangleID, NATID, FilterID},
Postrouting: {MangleID, NATID},
},
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4631ab93f..93592e7f5 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
case Output:
return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
- case Forward, Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
+ case Forward:
+ if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) {
+ return false
+ }
+
+ if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) {
+ return false
+ }
+
+ return true
+ case Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index b6cf24739..ac2fa777e 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -481,13 +481,9 @@ func TestDADResolve(t *testing.T) {
}
for _, test := range tests {
- test := test
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
e := channelLinkWithHeaderLength{
@@ -499,7 +495,9 @@ func TestDADResolve(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
+ Clock: clock,
SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
@@ -529,14 +527,10 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ const delta = time.Nanosecond
+ clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Error(err)
}
@@ -566,13 +560,14 @@ func TestDADResolve(t *testing.T) {
}
// Wait for DAD to resolve.
+ clock.Advance(delta)
select {
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID)
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
t.Error(err)
@@ -1146,57 +1141,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
-// TestNoRouterDiscovery tests that router discovery will not be performed if
-// configured not to.
-func TestNoRouterDiscovery(t *testing.T) {
- // Being configured to discover routers means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // router discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+func TestDynamicConfigurationsDisabled(t *testing.T) {
+ const (
+ nicID = 1
+ maxRtrSolicitDelay = time.Second
+ )
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ prefix := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse6("102:304:506:708::"),
+ PrefixLen: 64,
+ }
- // Rx an RA with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router when configured not to")
- default:
+ tests := []struct {
+ name string
+ config func(bool) ipv6.NDPConfigurations
+ ra *stack.PacketBuffer
+ }{
+ {
+ name: "No Router Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
+ },
+ ra: raBuf(llAddr2, 1000),
+ },
+ {
+ name: "No Prefix Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0),
+ },
+ {
+ name: "No Autogenerate Addresses",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Being configured to discover routers/prefixes or auto-generate
+ // addresses means RAs must be handled, and router/prefix discovery or
+ // SLAAC must be enabled.
+ //
+ // This tests all possible combinations of the configurations where
+ // router/prefix discovery or SLAAC are disabled.
+ for i := 0; i < 7; i++ {
+ handle := ipv6.HandlingRAsDisabled
+ if i&1 != 0 {
+ handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled
+ }
+ enable := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ ndpConfigs := test.config(enable)
+ ndpConfigs.HandleRAs = handle
+ ndpConfigs.MaxRtrSolicitations = 1
+ ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay
+ ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
+ })
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding
+ ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err)
+ }
+ stats := ep.Stats()
+ v6Stats, ok := stats.(*ipv6.Stats)
+ if !ok {
+ t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats)
+ }
+
+ // Make sure that when handling RAs are enabled, we solicit routers.
+ clock.Advance(maxRtrSolicitDelay)
+ if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want)
+ }
+ if handleRAsDisabled {
+ if p, ok := e.Read(); ok {
+ t.Errorf("unexpectedly got a packet = %#v", p)
+ }
+ } else if p, ok := e.Read(); !ok {
+ t.Error("expected router solicitation packet")
+ } else if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ } else {
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(nil)),
+ )
+ }
+
+ // Make sure we do not discover any routers or prefixes, or perform
+ // SLAAC on reception of an RA.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone())
+ // Make sure that the unhandled RA stat is only incremented when
+ // handling RAs is disabled.
+ if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e)
+ default:
+ }
+ })
}
})
}
}
+func boolToUint64(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// discovered flag set to discovered.
func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
}
+func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
+ tests := [...]struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ forwarding bool
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding disabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding enabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f(t, test.handleRAs, test.forwarding)
+ })
+ }
+}
+
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
@@ -1207,7 +1343,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1241,103 +1377,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
}
- default:
- t.Fatal("expected router discovery event")
}
- }
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for router discovery event")
}
- }
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
-
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
+ // Rx an RA from lladdr2 with lesser lifetime.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ })
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1351,7 +1493,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1390,57 +1532,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
}
-// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
-// configured not to.
-func TestNoPrefixDiscovery(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: testutil.MustParse6("102:304:506:708::"),
- PrefixLen: 64,
- }
-
- // Being configured to discover prefixes means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // prefix discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
-
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for prefix on nic with ID 1, and the
// discovered flag set to discovered.
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
@@ -1459,8 +1550,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1498,87 +1588,93 @@ func TestPrefixDiscovery(t *testing.T) {
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
}
- default:
- t.Fatal("expected prefix discovery event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- default:
- }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- expectPrefixEvent(subnet1, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- expectPrefixEvent(subnet2, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
- // Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- expectPrefixEvent(subnet3, true)
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
- // Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- expectPrefixEvent(subnet1, false)
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
- // Receive an RA with prefix2 in a PI with lesser lifetime.
- lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly received prefix event when updating lifetime")
- default:
- }
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
- // Wait for prefix2's most recent invalidation job plus some buffer to
- // expire.
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
+ }
+
+ // Wait for prefix2's most recent invalidation job plus some buffer to
+ // expire.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for prefix discovery event")
- }
- // Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- expectPrefixEvent(subnet3, false)
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+ })
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
@@ -1607,7 +1703,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1692,7 +1788,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: false,
DiscoverOnLinkPrefixes: true,
},
@@ -1757,53 +1853,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
return containsAddr(list, protocolAddress)
}
-// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
-func TestNoAutoGenAddr(t *testing.T) {
- prefix, _, _ := prefixSubnetAddr(0, "")
-
- // Being configured to auto-generate addresses means handle and
- // autogen are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, autogen =
- // true and forwarding = false (the required configuration to do
- // SLAAC) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- autogen := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
-
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for addr on nic with ID 1, and the
// event type is set to eventType.
func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
@@ -1812,7 +1861,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr2(t *testing.T) {
+func TestAutoGenAddr(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
@@ -1824,96 +1873,102 @@ func TestAutoGenAddr2(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
- default:
- t.Fatal("expected addr auto gen event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- default:
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- default:
- }
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
- // Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- default:
- }
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
- // Wait for addr of prefix1 to be invalidated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+ })
}
func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
@@ -2001,7 +2056,7 @@ func TestAutoGenTempAddr(t *testing.T) {
RetransmitTimer: test.retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2302,7 +2357,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2389,7 +2444,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2538,7 +2593,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2739,7 +2794,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
Clock: clock,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
@@ -2884,7 +2939,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: ndpDisp,
@@ -3351,7 +3406,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3494,7 +3549,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3561,7 +3616,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3727,7 +3782,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3809,7 +3864,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3973,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
@@ -4000,7 +4055,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Temporary address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -4150,7 +4205,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4278,7 +4333,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4484,7 +4539,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4535,7 +4590,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4629,8 +4684,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
}
-// TestCleanupNDPState tests that all discovered routers and prefixes, and
-// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
+ const (
+ lifetimeSeconds = 999
+ nicID = 1
+ )
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenLinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
+ if err := s.CreateNIC(nicID, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID)
+ }
+
+ prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1)
+ e1.InjectInbound(
+ header.IPv6ProtocolNumber,
+ raBufWithPI(
+ llAddr3,
+ lifetimeSeconds,
+ prefix,
+ true, /* onLink */
+ true, /* auto */
+ lifetimeSeconds,
+ lifetimeSeconds,
+ ),
+ )
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID)
+ }
+
+ // Enabling or disabling forwarding should not invalidate discovered prefixes
+ // or routers, or auto-generated address.
+ for _, forwarding := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) {
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpected router event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpected prefix event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpected auto-gen addr event = %#v", e)
+ default:
+ }
+ })
+ }
+}
+
func TestCleanupNDPState(t *testing.T) {
const (
lifetimeSeconds = 5
@@ -4659,18 +4816,6 @@ func TestCleanupNDPState(t *testing.T) {
maxAutoGenAddrEvents int
skipFinalAddrCheck bool
}{
- // A NIC should still keep its auto-generated link-local address when
- // becoming a router.
- {
- name: "Enable forwarding",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(ipv6.ProtocolNumber, true)
- },
- keepAutoGenLinkLocal: true,
- maxAutoGenAddrEvents: 4,
- },
-
// A NIC should cleanup all NDP state when it is disabled.
{
name: "Disable NIC",
@@ -4722,7 +4867,7 @@ func TestCleanupNDPState(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
DiscoverOnLinkPrefixes: true,
AutoGenGlobalAddresses: true,
@@ -4995,7 +5140,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -5186,96 +5331,127 @@ func TestRouterSolicitation(t *testing.T) {
},
}
+ subTests := []struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ afterFirstRS func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ afterFirstRS: func(*testing.T, *stack.Stack) {},
+ },
+
+ // Enabling forwarding when RAs are always configured to be handled
+ // should not stop router solicitations.
+ {
+ name: "Handle RAs always",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ afterFirstRS: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ },
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
- }
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
+ }
- clock.Advance(timeout)
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected router solicitation packet")
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: subTest.handleRAs,
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- clock.Advance(timeout)
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpectedly got a packet = %#v", p)
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay)
- remaining--
- }
+ subTest.afterFirstRS(t, s)
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
- waitForPkt(time.Nanosecond)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt)
- }
- }
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt)
+ }
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt)
- } else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay)
- }
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
- if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
}
})
}
@@ -5300,11 +5476,17 @@ func TestStopStartSolicitingRouters(t *testing.T) {
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(ipv6.ProtocolNumber, false)
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err)
+ }
},
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
- s.SetForwarding(ipv6.ProtocolNumber, true)
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
},
},
@@ -5373,6 +5555,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
MaxRtrSolicitations: maxRtrSolicitations,
RtrSolicitationInterval: interval,
MaxRtrSolicitationDelay: delay,
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 48bb75e2f..9821a18d3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
func BenchmarkCacheClear(b *testing.B) {
b.StopTimer()
config := DefaultNUDConfigurations()
- clock := &tcpip.StdClock{}
+ clock := tcpip.NewStdClock()
linkRes := newTestNeighborResolver(nil, config, clock)
linkRes.delay = 0
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8d615500f..dbba2c79f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t
return d.CheckDuplicateAddress(addr, h), nil
}
+
+func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ forwardingEP.SetForwarding(enable)
+ return nil
+}
+
+func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return false, &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return false, &tcpip.ErrNotSupported{}
+ }
+
+ return forwardingEP.Forwarding(), nil
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 646979d1e..4ca702121 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -16,9 +16,10 @@ package stack
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -39,7 +40,11 @@ type PacketBufferOptions struct {
// Data is the initial unparsed data for the new packet. If set, it will be
// owned by the new packet.
- Data buffer.VectorisedView
+ Data tcpipbuffer.VectorisedView
+
+ // IsForwardedPacket identifies that the PacketBuffer being created is for a
+ // forwarded packet.
+ IsForwardedPacket bool
}
// A PacketBuffer contains all the data of a network packet.
@@ -52,6 +57,34 @@ type PacketBufferOptions struct {
// empty. Use of PacketBuffer in any other order is unsupported.
//
// PacketBuffer must be created with NewPacketBuffer.
+//
+// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which
+// exposes a logically-contiguous byte storage. The underlying storage structure
+// is abstracted out, and should not be a concern here for most of the time.
+//
+// |- reserved ->|
+// |--->| consumed (incoming)
+// 0 V V
+// +--------+----+----+--------------------+
+// | | | | current data ... | (buf)
+// +--------+----+----+--------------------+
+// ^ |
+// |<---| pushed (outgoing)
+//
+// When a PacketBuffer is created, a `reserved` header region can be specified,
+// which stack pushes headers in this region for an outgoing packet. There could
+// be no such region for an incoming packet, and `reserved` is 0. The value of
+// `reserved` never changes in the entire lifetime of the packet.
+//
+// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the
+// pushed length, and the current value is stored for each header. PacketBuffer
+// substracts this value from `reserved` to compute the starting offset of each
+// header in `buf`.
+//
+// Incoming Packet: When a header is consumed (a.k.a. parsed), the current
+// `consumed` value is stored for each header, and it gets incremented by the
+// consumed length. PacketBuffer adds this value to `reserved` to compute the
+// starting offset of each header in `buf`.
type PacketBuffer struct {
_ sync.NoCopy
@@ -59,28 +92,16 @@ type PacketBuffer struct {
// PacketBuffers.
PacketBufferEntry
- // data holds the payload of the packet.
- //
- // For inbound packets, Data is initially the whole packet. Then gets moved to
- // headers via PacketHeader.Consume, when the packet is being parsed.
- //
- // For outbound packets, Data is the innermost layer, defined by the protocol.
- // Headers are pushed in front of it via PacketHeader.Push.
- //
- // The bytes backing Data are immutable, a.k.a. users shouldn't write to its
- // backing storage.
- data buffer.VectorisedView
+ // buf is the underlying buffer for the packet. See struct level docs for
+ // details.
+ buf *buffer.Buffer
+ reserved int
+ pushed int
+ consumed int
// headers stores metadata about each header.
headers [numHeaderType]headerInfo
- // header is the internal storage for outbound packets. Headers will be pushed
- // (prepended) on this storage as the packet is being constructed.
- //
- // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and
- // data are held in the same underlying buffer storage.
- header buffer.Prependable
-
// NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
// returns false.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
@@ -127,10 +148,17 @@ type PacketBuffer struct {
// NewPacketBuffer creates a new PacketBuffer with opts.
func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
pk := &PacketBuffer{
- data: opts.Data,
+ buf: &buffer.Buffer{},
}
if opts.ReserveHeaderBytes != 0 {
- pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
+ pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes))
+ pk.reserved = opts.ReserveHeaderBytes
+ }
+ for _, v := range opts.Data.Views() {
+ pk.buf.AppendOwned(v)
+ }
+ if opts.IsForwardedPacket {
+ pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket
}
return pk
}
@@ -138,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
// ReservedHeaderBytes returns the number of bytes initially reserved for
// headers.
func (pk *PacketBuffer) ReservedHeaderBytes() int {
- return pk.header.UsedLength() + pk.header.AvailableLength()
+ return pk.reserved
}
// AvailableHeaderBytes returns the number of bytes currently available for
// headers. This is relevant to PacketHeader.Push method only.
func (pk *PacketBuffer) AvailableHeaderBytes() int {
- return pk.header.AvailableLength()
+ return pk.reserved - pk.pushed
}
// LinkHeader returns the handle to link-layer header.
@@ -173,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader {
// HeaderSize returns the total size of all headers in bytes.
func (pk *PacketBuffer) HeaderSize() int {
- // Note for inbound packets (Consume called), headers are not stored in
- // pk.header. Thus, calculation of size of each header is needed.
- var size int
- for i := range pk.headers {
- size += len(pk.headers[i].buf)
- }
- return size
+ return pk.pushed + pk.consumed
}
// Size returns the size of packet in bytes.
func (pk *PacketBuffer) Size() int {
- return pk.HeaderSize() + pk.data.Size()
+ return int(pk.buf.Size()) - pk.headerOffset()
}
// MemSize returns the estimation size of the pk in memory, including backing
// buffer data.
func (pk *PacketBuffer) MemSize() int {
- return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize
+ return int(pk.buf.Size()) + packetBufferStructSize
}
// Data returns the handle to data portion of pk.
@@ -199,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData {
}
// Views returns the underlying storage of the whole packet.
-func (pk *PacketBuffer) Views() []buffer.View {
- // Optimization for outbound packets that headers are in pk.header.
- useHeader := true
- for i := range pk.headers {
- if !canUseHeader(&pk.headers[i]) {
- useHeader = false
- break
- }
- }
+func (pk *PacketBuffer) Views() []tcpipbuffer.View {
+ var views []tcpipbuffer.View
+ offset := pk.headerOffset()
+ pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) {
+ views = append(views, v)
+ })
+ return views
+}
- dataViews := pk.data.Views()
-
- var vs []buffer.View
- if useHeader {
- vs = make([]buffer.View, 0, 1+len(dataViews))
- vs = append(vs, pk.header.View())
- } else {
- vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews))
- for i := range pk.headers {
- if v := pk.headers[i].buf; len(v) > 0 {
- vs = append(vs, v)
- }
- }
- }
- return append(vs, dataViews...)
+func (pk *PacketBuffer) headerOffset() int {
+ return pk.reserved - pk.pushed
+}
+
+func (pk *PacketBuffer) headerOffsetOf(typ headerType) int {
+ return pk.reserved + pk.headers[typ].offset
}
-func canUseHeader(h *headerInfo) bool {
- // h.offset will be negative if the header was pushed in to prependable
- // portion, or doesn't matter when it's empty.
- return len(h.buf) == 0 || h.offset < 0
+func (pk *PacketBuffer) dataOffset() int {
+ return pk.reserved + pk.consumed
}
-func (pk *PacketBuffer) push(typ headerType, size int) buffer.View {
+func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View {
h := &pk.headers[typ]
- if h.buf != nil {
- panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ if h.length > 0 {
+ panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size))
+ }
+ if pk.pushed+size > pk.reserved {
+ panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved))
}
- h.buf = buffer.View(pk.header.Prepend(size))
- h.offset = -pk.header.UsedLength()
- return h.buf
+ pk.pushed += size
+ h.offset = -pk.pushed
+ h.length = size
+ return pk.headerView(typ)
}
-func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) {
+func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) {
h := &pk.headers[typ]
- if h.buf != nil {
+ if h.length > 0 {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
- v, ok := pk.data.PullUp(size)
+ if pk.reserved+pk.consumed+size > int(pk.buf.Size()) {
+ return nil, false
+ }
+ h.offset = pk.consumed
+ h.length = size
+ pk.consumed += size
+ return pk.headerView(typ), true
+}
+
+func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
+ h := &pk.headers[typ]
+ if h.length == 0 {
+ return nil
+ }
+ v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length)
if !ok {
- return
+ panic("PullUp failed")
}
- pk.data.TrimFront(size)
- h.buf = v
- return h.buf, true
+ return v
}
// Clone makes a shallow copy of pk.
@@ -263,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- data: pk.data.Clone(nil),
+ buf: pk.buf,
+ reserved: pk.reserved,
+ pushed: pk.pushed,
+ consumed: pk.consumed,
headers: pk.headers,
- header: pk.header,
Hash: pk.Hash,
Owner: pk.Owner,
GSOOptions: pk.GSOOptions,
@@ -299,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
- newPk := NewPacketBuffer(PacketBufferOptions{
- Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
- })
+ newPk := &PacketBuffer{
+ buf: pk.buf,
+ // Treat unfilled header portion as reserved.
+ reserved: pk.AvailableHeaderBytes(),
+ }
// TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
// maintain this flag in the packet. Currently conntrack needs this flag to
// tell if a noop connection should be inserted at Input hook. Once conntrack
@@ -315,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
- // buf is the memorized slice for both prepended and consumed header.
- // When header is prepended, buf serves as memorized value, which is a slice
- // of pk.header. When header is consumed, buf is the slice pulled out from
- // pk.Data, which is the only place to hold this header.
- buf buffer.View
-
- // offset will be a negative number denoting the offset where this header is
- // from the end of pk.header, if it is prepended. Otherwise, zero.
+ // offset is the offset of the header in pk.buf relative to
+ // pk.buf[pk.reserved]. See the PacketBuffer struct for details.
offset int
+
+ // length is the length of this header.
+ length int
}
// PacketHeader is a handle object to a header in the underlying packet.
@@ -333,14 +360,14 @@ type PacketHeader struct {
}
// View returns the underlying storage of h.
-func (h PacketHeader) View() buffer.View {
- return h.pk.headers[h.typ].buf
+func (h PacketHeader) View() tcpipbuffer.View {
+ return h.pk.headerView(h.typ)
}
// Push pushes size bytes in the front of its residing packet, and returns the
// backing storage. Callers may only call one of Push or Consume once on each
// header in the lifetime of the underlying packet.
-func (h PacketHeader) Push(size int) buffer.View {
+func (h PacketHeader) Push(size int) tcpipbuffer.View {
return h.pk.push(h.typ, size)
}
@@ -349,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View {
// size, consumed will be false, and the state of h will not be affected.
// Callers may only call one of Push or Consume once on each header in the
// lifetime of the underlying packet.
-func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
+func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) {
return h.pk.consume(h.typ, size)
}
@@ -360,54 +387,84 @@ type PacketData struct {
// PullUp returns a contiguous view of size bytes from the beginning of d.
// Callers should not write to or keep the view for later use.
-func (d PacketData) PullUp(size int) (buffer.View, bool) {
- return d.pk.data.PullUp(size)
+func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
+ return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// TrimFront removes count from the beginning of d. It panics if count >
-// d.Size().
-func (d PacketData) TrimFront(count int) {
- d.pk.data.TrimFront(count)
+// DeleteFront removes count from the beginning of d. It panics if count >
+// d.Size(). All backing storage references after the front of the d are
+// invalidated.
+func (d PacketData) DeleteFront(count int) {
+ if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
+ panic("count > d.Size()")
+ }
}
// CapLength reduces d to at most length bytes.
func (d PacketData) CapLength(length int) {
- d.pk.data.CapLength(length)
+ if length < 0 {
+ panic("length < 0")
+ }
+ if currLength := d.Size(); currLength > length {
+ trim := currLength - length
+ d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim)
+ }
}
// Views returns the underlying storage of d in a slice of Views. Caller should
// not modify the returned slice.
-func (d PacketData) Views() []buffer.View {
- return d.pk.data.Views()
+func (d PacketData) Views() []tcpipbuffer.View {
+ var views []tcpipbuffer.View
+ offset := d.pk.dataOffset()
+ d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) {
+ views = append(views, v)
+ })
+ return views
}
// AppendView appends v into d, taking the ownership of v.
-func (d PacketData) AppendView(v buffer.View) {
- d.pk.data.AppendView(v)
+func (d PacketData) AppendView(v tcpipbuffer.View) {
+ d.pk.buf.AppendOwned(v)
}
-// ReadFromData moves at most count bytes from the beginning of srcData to the
-// end of d and returns the number of bytes moved.
-func (d PacketData) ReadFromData(srcData PacketData, count int) int {
- return srcData.pk.data.ReadToVV(&d.pk.data, count)
+// MergeFragment appends the data portion of frag to dst. It takes ownership of
+// frag and frag should not be used again.
+func MergeFragment(dst, frag *PacketBuffer) {
+ frag.buf.TrimFront(int64(frag.dataOffset()))
+ dst.buf.Merge(frag.buf)
}
// ReadFromVV moves at most count bytes from the beginning of srcVV to the end
// of d and returns the number of bytes moved.
-func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int {
- return srcVV.ReadToVV(&d.pk.data, count)
+func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int {
+ done := 0
+ for _, v := range srcVV.Views() {
+ if len(v) < count {
+ count -= len(v)
+ done += len(v)
+ d.pk.buf.AppendOwned(v)
+ } else {
+ v = v[:count]
+ count -= len(v)
+ done += len(v)
+ d.pk.buf.Append(v)
+ break
+ }
+ }
+ srcVV.TrimFront(done)
+ return done
}
// Size returns the number of bytes in the data payload of the packet.
func (d PacketData) Size() int {
- return d.pk.data.Size()
+ return int(d.pk.buf.Size()) - d.pk.dataOffset()
}
// AsRange returns a Range representing the current data payload of the packet.
func (d PacketData) AsRange() Range {
return Range{
pk: d.pk,
- offset: d.pk.HeaderSize(),
+ offset: d.pk.dataOffset(),
length: d.Size(),
}
}
@@ -417,17 +474,12 @@ func (d PacketData) AsRange() Range {
//
// This method exists for compatibility between PacketBuffer and VectorisedView.
// It may be removed later and should be used with care.
-func (d PacketData) ExtractVV() buffer.VectorisedView {
- return d.pk.data
-}
-
-// Replace replaces the data portion of the packet with vv, taking the ownership
-// of vv.
-//
-// This method exists for compatibility between PacketBuffer and VectorisedView.
-// It may be removed later and should be used with care.
-func (d PacketData) Replace(vv buffer.VectorisedView) {
- d.pk.data = vv
+func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView {
+ var vv tcpipbuffer.VectorisedView
+ d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) {
+ vv.AppendView(v)
+ })
+ return vv
}
// Range represents a contiguous subportion of a PacketBuffer.
@@ -471,9 +523,9 @@ func (r Range) Capped(max int) Range {
// AsView returns the backing storage of r if possible. It will allocate a new
// View if r spans multiple pieces internally. Caller should not write to the
// returned View in any way.
-func (r Range) AsView() buffer.View {
+func (r Range) AsView() tcpipbuffer.View {
var allocated bool
- var v buffer.View
+ var v tcpipbuffer.View
r.iterate(func(b []byte) {
if v == nil {
// v has not been assigned, allowing first view to be returned.
@@ -494,7 +546,7 @@ func (r Range) AsView() buffer.View {
}
// ToOwnedView returns a owned copy of data in r.
-func (r Range) ToOwnedView() buffer.View {
+func (r Range) ToOwnedView() tcpipbuffer.View {
if r.length == 0 {
return nil
}
@@ -515,63 +567,7 @@ func (r Range) Checksum() uint16 {
// iterate calls fn for each piece in r. fn is always called with a non-empty
// slice.
func (r Range) iterate(fn func([]byte)) {
- w := window{
- offset: r.offset,
- length: r.length,
- }
- // Header portion.
- for i := range r.pk.headers {
- if b := w.process(r.pk.headers[i].buf); len(b) > 0 {
- fn(b)
- }
- if w.isDone() {
- break
- }
- }
- // Data portion.
- if !w.isDone() {
- for _, v := range r.pk.data.Views() {
- if b := w.process(v); len(b) > 0 {
- fn(b)
- }
- if w.isDone() {
- break
- }
- }
- }
-}
-
-// window represents contiguous region of byte stream. User would call process()
-// to input bytes, and obtain a subslice that is inside the window.
-type window struct {
- offset int
- length int
-}
-
-// isDone returns true if the window has passed and further process() calls will
-// always return an empty slice. This can be used to end processing early.
-func (w *window) isDone() bool {
- return w.length == 0
-}
-
-// process feeds b in and returns a subslice that is inside the window. The
-// returned slice will be a subslice of b, and it does not keep b after method
-// returns. This method may return an empty slice if nothing in b is inside the
-// window.
-func (w *window) process(b []byte) (inWindow []byte) {
- if w.offset >= len(b) {
- w.offset -= len(b)
- return nil
- }
- if w.offset > 0 {
- b = b[w.offset:]
- w.offset = 0
- }
- if w.length < len(b) {
- b = b[:w.length]
- }
- w.length -= len(b)
- return b
+ r.pk.buf.SubApply(r.offset, r.length, fn)
}
// PayloadSince returns packet payload starting from and including a particular
@@ -579,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) {
//
// The returned View is owned by the caller - its backing buffer is separate
// from the packet header's underlying packet buffer.
-func PayloadSince(h PacketHeader) buffer.View {
- size := h.pk.data.Size()
- for _, hinfo := range h.pk.headers[h.typ:] {
- size += len(hinfo.buf)
+func PayloadSince(h PacketHeader) tcpipbuffer.View {
+ offset := h.pk.headerOffset()
+ for i := headerType(0); i < h.typ; i++ {
+ offset += h.pk.headers[i].length
}
-
- v := make(buffer.View, 0, size)
-
- for _, hinfo := range h.pk.headers[h.typ:] {
- v = append(v, hinfo.buf...)
- }
-
- for _, view := range h.pk.data.Views() {
- v = append(v, view...)
- }
-
- return v
+ return Range{
+ pk: h.pk,
+ offset: offset,
+ length: int(h.pk.buf.Size()) - offset,
+ }.ToOwnedView()
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 6728370c3..a8da34992 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, test.data)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
- concatViews(test.link, test.network, test.transport, test.data))
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(test.link, test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(test.transport, test.data))
+ // Check the after state.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.link,
+ network: test.network,
+ transport: test.transport,
+ data: test.data,
+ })
})
}
}
@@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) {
if got, want := pk.Size(), len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- // After state of pk.
- var (
- link = test.data[:test.link]
- network = test.data[test.link:][:test.network]
- transport = test.data[test.link+test.network:][:test.transport]
- payload = test.data[allHdrSize:]
- )
- checkData(t, pk, payload)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(link, network, transport, payload))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(network, transport, payload))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(transport, payload))
+ // Check the after state of pk.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.data[:test.link],
+ network: test.data[test.link:][:test.network],
+ transport: test.data[test.link+test.network:][:test.transport],
+ data: test.data[allHdrSize:],
+ })
})
}
}
@@ -252,6 +226,70 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
})
}
+// This is a very obscure use-case seen in the code that verifies packets
+// before sending them out. It tries to parse the headers to verify.
+// PacketHeader was initially not designed to mix Push() and Consume(), but it
+// works and it's been relied upon. Include a test here.
+func TestPacketHeaderPushConsumeMixed(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := append([]byte(nil), network...)
+ initData = append(initData, data...)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Consume network header
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(network))
+ if !ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network))
+ }
+ checkViewEqual(t, "gotNetwork", gotNetwork, network)
+
+ // 2. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ network: network,
+ data: data,
+ })
+}
+
+func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := concatViews(network, data)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ data: initData,
+ })
+
+ // 2. Consume network header, with a number of bytes too large.
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1)
+ if ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork)
+ }
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ data: initData,
+ })
+}
+
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
const headerSize = 10
@@ -397,11 +435,11 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // TrimFront
+ // DeleteFront
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().TrimFront(n)
+ pkt.Data().DeleteFront(n)
checkData(t, pkt, []byte(tc.data)[n:])
})
@@ -437,23 +475,8 @@ func TestPacketBufferData(t *testing.T) {
checkData(t, pkt, []byte(tc.data+s))
})
- // ReadFromData/VV
+ // ReadFromVV
for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
- t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) {
- s := "TO READ"
- otherPkt := NewPacketBuffer(PacketBufferOptions{
- Data: vv(s, s),
- })
- s += s
-
- pkt := tc.makePkt(t)
- pkt.Data().ReadFromData(otherPkt.Data(), n)
-
- if n < len(s) {
- s = s[:n]
- }
- checkData(t, pkt, []byte(tc.data+s))
- })
t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
s := "TO READ"
srcVV := vv(s, s)
@@ -480,20 +503,41 @@ func TestPacketBufferData(t *testing.T) {
t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
}
})
-
- // Replace
- t.Run("Replace", func(t *testing.T) {
- s := "REPLACED"
-
- pkt := tc.makePkt(t)
- pkt.Data().Replace(vv(s))
-
- checkData(t, pkt, []byte(s))
- })
})
}
}
+type packetContents struct {
+ link buffer.View
+ network buffer.View
+ transport buffer.View
+ data buffer.View
+}
+
+func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) {
+ t.Helper()
+ // Headers.
+ checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link)
+ checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network)
+ checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport)
+ // Data.
+ checkData(t, pk, want.data)
+ // Whole packet.
+ checkViewEqual(t, prefix+"pk.Views()",
+ concatViews(pk.Views()...),
+ concatViews(want.link, want.network, want.transport, want.data))
+ // PayloadSince.
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(want.link, want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(want.transport, want.data))
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -510,19 +554,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, data)
- checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
- // Check the initial values for each header.
- checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
- checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
- checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
- // Check the initial valies for PayloadSince.
- checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()), data)
+ checkPacketContents(t, "Initial ", pk, packetContents{
+ data: data,
+ })
}
func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
@@ -540,7 +574,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
t.Helper()
if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) {
- t.Errorf("pkt.Data().Views() = %x, want %x", got, want)
+ t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want)
}
if got := pkt.Data().Size(); got != len(want) {
t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want))
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 7ad206f6d..85bb87b4b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -55,6 +55,9 @@ type NetworkPacketInfo struct {
// LocalAddressBroadcast is true if the packet's local address is a broadcast
// address.
LocalAddressBroadcast bool
+
+ // IsForwardedPacket is true if the packet is being forwarded.
+ IsForwardedPacket bool
}
// TransportErrorKind enumerates error types that are handled by the transport
@@ -655,9 +658,9 @@ type IPNetworkEndpointStats interface {
IPStats() *tcpip.IPStats
}
-// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
-type ForwardingNetworkProtocol interface {
- NetworkProtocol
+// ForwardingNetworkEndpoint is a network endpoint that may forward packets.
+type ForwardingNetworkEndpoint interface {
+ NetworkEndpoint
// Forwarding returns the forwarding configuration.
Forwarding() bool
@@ -756,11 +759,6 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityHardwareGSO
-
- // CapabilitySoftwareGSO indicates the link endpoint supports of sending
- // multiple packets using a single call (LinkEndpoint.WritePackets).
- CapabilitySoftwareGSO
)
// NetworkLinkEndpoint is a data-link layer that supports sending network
@@ -1047,10 +1045,29 @@ type GSO struct {
MaxSize uint32
}
+// SupportedGSO returns the type of segmentation offloading supported.
+type SupportedGSO int
+
+const (
+ // GSONotSupported indicates that segmentation offloading is not supported.
+ GSONotSupported SupportedGSO = iota
+
+ // HWGSOSupported indicates that segmentation offloading may be performed by
+ // the hardware.
+ HWGSOSupported
+
+ // SWGSOSupported indicates that segmentation offloading may be performed in
+ // software.
+ SWGSOSupported
+)
+
// GSOEndpoint provides access to GSO properties.
type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
+
+ // SupportedGSO returns the supported segmentation offloading.
+ SupportedGSO() SupportedGSO
}
// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 4ecde5995..f17c04277 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool {
// HasSoftwareGSOCapability returns true if the route supports software GSO.
func (r *Route) HasSoftwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == SWGSOSupported
+ }
+ return false
}
// HasHardwareGSOCapability returns true if the route supports hardware GSO.
func (r *Route) HasHardwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == HWGSOSupported
+ }
+ return false
}
// HasSaveRestoreCapability returns true if the route supports save/restore.
@@ -440,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool {
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
- if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
+ if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 843118b13..8814f45a6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,6 +29,7 @@ import (
"time"
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -65,10 +66,10 @@ type ResumableEndpoint interface {
}
// uniqueIDGenerator is a default unique ID generator.
-type uniqueIDGenerator uint64
+type uniqueIDGenerator atomicbitops.AlignedAtomicUint64
func (u *uniqueIDGenerator) UniqueID() uint64 {
- return atomic.AddUint64((*uint64)(u), 1)
+ return ((*atomicbitops.AlignedAtomicUint64)(u)).Add(1)
}
// Stack is a networking stack, with all supported protocols, NICs, and route
@@ -94,8 +95,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -322,7 +324,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
- clock = &tcpip.StdClock{}
+ clock = tcpip.NewStdClock()
}
if opts.UniqueID == nil {
@@ -347,22 +349,23 @@ func New(opts Options) *Stack {
}
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -491,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables packet forwarding between NICs for the
-// passed protocol.
-func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -658,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -785,6 +817,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -814,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
@@ -824,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1028,6 +1080,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1080,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1119,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8ead3b8df..02d54d29b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct {
mu struct {
sync.RWMutex
- enabled bool
+ enabled bool
+ forwarding bool
}
nic stack.NetworkInterface
@@ -138,11 +139,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data().TrimFront(fakeNetHeaderLen)
+ // DeleteFront invalidates slices. Make a copy before trimming.
+ nb := append([]byte(nil), hdr...)
+ pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -225,11 +228,6 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -298,15 +296,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
@@ -3020,7 +3018,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -4218,8 +4216,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
}
- if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
- t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
@@ -4273,8 +4271,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
// Disabling forwarding when the route is dependent on forwarding being
// enabled should make the route invalid.
- if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
- t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err)
}
{
err := send(r, data)
diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go
new file mode 100644
index 000000000..7ce43a68e
--- /dev/null
+++ b/pkg/tcpip/stdclock.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// stdClock implements Clock with the time package.
+//
+// +stateify savable
+type stdClock struct {
+ // baseTime holds the time when the clock was constructed.
+ //
+ // This value is used to calculate the monotonic time from the time package.
+ // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks,
+ //
+ // Operating systems provide both a “wall clock,” which is subject to
+ // changes for clock synchronization, and a “monotonic clock,” which is not.
+ // The general rule is that the wall clock is for telling time and the
+ // monotonic clock is for measuring time. Rather than split the API, in this
+ // package the Time returned by time.Now contains both a wall clock reading
+ // and a monotonic clock reading; later time-telling operations use the wall
+ // clock reading, but later time-measuring operations, specifically
+ // comparisons and subtractions, use the monotonic clock reading.
+ //
+ // ...
+ //
+ // If Times t and u both contain monotonic clock readings, the operations
+ // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using
+ // the monotonic clock readings alone, ignoring the wall clock readings. If
+ // either t or u contains no monotonic clock reading, these operations fall
+ // back to using the wall clock readings.
+ //
+ // Given the above, we can safely conclude that time.Since(baseTime) will
+ // return monotonically increasing values if we use time.Now() to set baseTime
+ // at the time of clock construction.
+ //
+ // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per
+ // https://golang.org/pkg/time/#Since.
+ baseTime time.Time `state:"nosave"`
+
+ // monotonicOffset is the offset applied to the calculated monotonic time.
+ //
+ // monotonicOffset is assigned maxMonotonic after restore so that the
+ // monotonic time will continue from where it "left off" before saving as part
+ // of S/R.
+ monotonicOffset int64 `state:"nosave"`
+
+ // monotonicMU protects maxMonotonic.
+ monotonicMU sync.Mutex `state:"nosave"`
+ maxMonotonic int64
+}
+
+// NewStdClock returns an instance of a clock that uses the time package.
+func NewStdClock() Clock {
+ return &stdClock{
+ baseTime: time.Now(),
+ }
+}
+
+var _ Clock = (*stdClock)(nil)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*stdClock) NowNanoseconds() int64 {
+ return time.Now().UnixNano()
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (s *stdClock) NowMonotonic() int64 {
+ sinceBase := time.Since(s.baseTime)
+ if sinceBase < 0 {
+ panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime))
+ }
+
+ monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset
+
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+
+ // Monotonic time values must never decrease.
+ if monotonicValue > s.maxMonotonic {
+ s.maxMonotonic = monotonicValue
+ }
+
+ return s.maxMonotonic
+}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*stdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/stdclock_state.go b/pkg/tcpip/stdclock_state.go
new file mode 100644
index 000000000..795db9181
--- /dev/null
+++ b/pkg/tcpip/stdclock_state.go
@@ -0,0 +1,26 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import "time"
+
+// afterLoad is invoked by stateify.
+func (s *stdClock) afterLoad() {
+ s.baseTime = time.Now()
+
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+ s.monotonicOffset = s.maxMonotonic
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0ba71b62e..797778e08 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -37,9 +37,9 @@ import (
"reflect"
"strconv"
"strings"
- "sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -73,7 +73,7 @@ type Clock interface {
// nanoseconds since the Unix epoch.
NowNanoseconds() int64
- // NowMonotonic returns a monotonic time value.
+ // NowMonotonic returns a monotonic time value at nanosecond resolution.
NowMonotonic() int64
// AfterFunc waits for the duration to elapse and then calls f in its own
@@ -1107,6 +1107,7 @@ const (
// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
// duration for which a socket lingers before returning from Close.
//
+// +marshal
// +stateify savable
type LingerOption struct {
Enabled bool
@@ -1219,7 +1220,7 @@ type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
type StatCounter struct {
- count uint64
+ count atomicbitops.AlignedAtomicUint64
}
// Increment adds one to the counter.
@@ -1234,12 +1235,12 @@ func (s *StatCounter) Decrement() {
// Value returns the current value of the counter.
func (s *StatCounter) Value(name ...string) uint64 {
- return atomic.LoadUint64(&s.count)
+ return s.count.Load()
}
// IncrementBy increments the counter by v.
func (s *StatCounter) IncrementBy(v uint64) {
- atomic.AddUint64(&s.count, v)
+ s.count.Add(v)
}
func (s *StatCounter) String() string {
@@ -1527,6 +1528,42 @@ type IGMPStats struct {
// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats)
}
+// IPForwardingStats collects stats related to IP forwarding (both v4 and v6).
+type IPForwardingStats struct {
+ // LINT.IfChange(IPForwardingStats)
+
+ // Unrouteable is the number of IP packets received which were dropped
+ // because a route to their destination could not be constructed.
+ Unrouteable *StatCounter
+
+ // ExhaustedTTL is the number of IP packets received which were dropped
+ // because their TTL was exhausted.
+ ExhaustedTTL *StatCounter
+
+ // LinkLocalSource is the number of IP packets which were dropped
+ // because they contained a link-local source address.
+ LinkLocalSource *StatCounter
+
+ // LinkLocalDestination is the number of IP packets which were dropped
+ // because they contained a link-local destination address.
+ LinkLocalDestination *StatCounter
+
+ // PacketTooBig is the number of IP packets which were dropped because they
+ // were too big for the outgoing MTU.
+ PacketTooBig *StatCounter
+
+ // ExtensionHeaderProblem is the number of IP packets which were dropped
+ // because of a problem encountered when processing an IPv6 extension
+ // header.
+ ExtensionHeaderProblem *StatCounter
+
+ // Errors is the number of IP packets received which could not be
+ // successfully forwarded.
+ Errors *StatCounter
+
+ // LINT.ThenChange(network/internal/ip/stats.go:multiCounterIPForwardingStats)
+}
+
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// LINT.IfChange(IPStats)
@@ -1534,6 +1571,10 @@ type IPStats struct {
// PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived *StatCounter
+ // ValidPacketsReceived is the number of valid IP packets that reached the IP
+ // layer.
+ ValidPacketsReceived *StatCounter
+
// DisabledPacketsReceived is the number of IP packets received from the link
// layer when the IP layer is disabled.
DisabledPacketsReceived *StatCounter
@@ -1573,6 +1614,10 @@ type IPStats struct {
// chain.
IPTablesInputDropped *StatCounter
+ // IPTablesForwardDropped is the number of IP packets dropped in the Forward
+ // chain.
+ IPTablesForwardDropped *StatCounter
+
// IPTablesOutputDropped is the number of IP packets dropped in the Output
// chain.
IPTablesOutputDropped *StatCounter
@@ -1595,6 +1640,9 @@ type IPStats struct {
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived *StatCounter
+ // Forwarding collects stats related to IP forwarding.
+ Forwarding IPForwardingStats
+
// LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats)
}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index d4f7bb5ff..ab2dab60c 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -31,12 +31,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/udp",
],
)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index dbd279c94..92fa6257d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -16,6 +16,7 @@ package forward_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -34,6 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+}
+
+func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+}
+
func TestForwarding(t *testing.T) {
const listenPort = 8080
@@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) {
const (
nicID1 = 1
nicID2 = 2
- ttl = 64
)
var (
ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
)
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
- }
-
- v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
- }
-
tests := []struct {
name string
srcAddr, dstAddr tcpip.Address
@@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
},
},
{
@@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
},
},
@@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
},
},
{
@@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
},
},
}
@@ -475,11 +480,11 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
}
- if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
}
- if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) {
})
}
}
+
+func TestPerInterfaceForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ // ARP is not used in this test but it is a network protocol that does
+ // not support forwarding. We install the protocol to make sure that
+ // forwarding information for a NIC is only reported for network
+ // protocols that support forwarding.
+ arp.NewProtocol,
+
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ for _, add := range [...]struct {
+ nicID tcpip.NICID
+ addr tcpip.ProtocolAddress
+ }{
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv4Addr,
+ },
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv6Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv4Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv6Addr,
+ },
+ } {
+ if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ }
+ }
+
+ // Only enable forwarding on NIC1 and make sure that only packets arriving
+ // on NIC1 are forwarded.
+ for _, netProto := range netProtos {
+ if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
+ t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
+ }
+ }
+
+ nicsInfo := s.NICInfo()
+ for _, subTest := range [...]struct {
+ nicID tcpip.NICID
+ nicEP *channel.Endpoint
+ otherNICID tcpip.NICID
+ otherNICEP *channel.Endpoint
+ expectForwarding bool
+ }{
+ {
+ nicID: nicID1,
+ nicEP: e1,
+ otherNICID: nicID2,
+ otherNICEP: e2,
+ expectForwarding: true,
+ },
+ {
+ nicID: nicID2,
+ nicEP: e2,
+ otherNICID: nicID2,
+ otherNICEP: e1,
+ expectForwarding: false,
+ },
+ } {
+ t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
+ nicInfo, ok := nicsInfo[subTest.nicID]
+ if !ok {
+ t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
+ } else {
+ forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
+ for _, netProto := range netProtos {
+ forwarding[netProto] = subTest.expectForwarding
+ }
+
+ if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
+ t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ })
+
+ test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
+ if p, ok := subTest.nicEP.Read(); ok {
+ t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
+ }
+ if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
+ t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
+ } else if subTest.expectForwarding {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index c61d4e788..07ba2b837 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -19,12 +19,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) {
})
}
}
+
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoReply(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoReply(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+}
+
+func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+}
+
+func TestForwardingHook(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ local bool
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 remote",
+ netProto: ipv4.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 local",
+ netProto: ipv4.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr.Address,
+ rx: rxICMPv4EchoReply,
+ },
+ {
+ name: "IPv6 remote",
+ netProto: ipv6.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 local",
+ netProto: ipv6.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ rx: rxICMPv6EchoReply,
+ },
+ }
+
+ setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Forward]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+ }
+
+ boolToInt := func(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectForward bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectForward: true,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+
+ {
+ name: "Drop with other input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+
+ {
+ name: "Drop with inverted input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with inverted output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ expectTransmitPacket := subTest.expectForward && !test.local
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
+ t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip1Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ p, ok := e2.Read()
+ if ok != expectTransmitPacket {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
+ }
+ if expectTransmitPacket {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 3df1bbd68..87d36e1dd 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) {
}
if test.forwarding {
- if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
}
- if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 8fd9be32b..2e6ae55ea 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
}
- if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err)
}
- if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
@@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
})
}
-// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
-// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv4UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, 0))
@@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
}))
}
-// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo)
+}
+
+// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on
+// the provided endpoint.
+func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply)
+}
+
+func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv6UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
Data: hdr.View().ToVectorisedView(),
}))
}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest)
+}
+
+// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on
+// the provided endpoint.
+func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply)
+}
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
index 472545a5d..02ee86ff1 100644
--- a/pkg/tcpip/testutil/BUILD
+++ b/pkg/tcpip/testutil/BUILD
@@ -5,7 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = True,
- srcs = ["testutil.go"],
+ srcs = [
+ "testutil.go",
+ "testutil_unsafe.go",
+ ],
visibility = ["//visibility:public"],
deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index 1aaed590f..f84d399fb 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -18,6 +18,8 @@ package testutil
import (
"fmt"
"net"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address {
}
return tcpip.Address(ip)
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go
index 5ff764800..5ff764800 100644
--- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
+++ b/pkg/tcpip/testutil/testutil_unsafe.go
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
deleted file mode 100644
index eeea97b12..000000000
--- a/pkg/tcpip/time_unsafe.go
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build go1.9
-// +build !go1.18
-
-// Check go:linkname function signatures when updating Go version.
-
-package tcpip
-
-import (
- "time" // Used with go:linkname.
- _ "unsafe" // Required for go:linkname.
-)
-
-// StdClock implements Clock with the time package.
-//
-// +stateify savable
-type StdClock struct{}
-
-var _ Clock = (*StdClock)(nil)
-
-//go:linkname now time.now
-func now() (sec int64, nsec int32, mono int64)
-
-// NowNanoseconds implements Clock.NowNanoseconds.
-func (*StdClock) NowNanoseconds() int64 {
- sec, nsec, _ := now()
- return sec*1e9 + int64(nsec)
-}
-
-// NowMonotonic implements Clock.NowMonotonic.
-func (*StdClock) NowMonotonic() int64 {
- _, _, mono := now()
- return mono
-}
-
-// AfterFunc implements Clock.AfterFunc.
-func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
- return &stdTimer{
- t: time.AfterFunc(d, f),
- }
-}
-
-type stdTimer struct {
- t *time.Timer
-}
-
-var _ Timer = (*stdTimer)(nil)
-
-// Stop implements Timer.Stop.
-func (st *stdTimer) Stop() bool {
- return st.t.Stop()
-}
-
-// Reset implements Timer.Reset.
-func (st *stdTimer) Reset(d time.Duration) {
- st.t.Reset(d)
-}
-
-// NewStdTimer returns a Timer implemented with the time package.
-func NewStdTimer(t *time.Timer) Timer {
- return &stdTimer{t: t}
-}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index a82384c49..1633d0aeb 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -29,7 +29,7 @@ const (
)
func TestJobReschedule(t *testing.T) {
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
wg.Done()
})
job.Schedule(shortDuration)
@@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) {
func TestJobExecution(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
ch <- struct{}{}
})
job.Schedule(shortDuration)
@@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(middleDuration)
lock.Lock()
@@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) {
func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
@@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
job.Cancel()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9948f305b..8afde7fca 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -747,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -756,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 48417f192..0f20d3856 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -126,7 +126,15 @@ go_test(
go_test(
name = "tcp_test",
size = "small",
- srcs = ["timer_test.go"],
+ srcs = [
+ "segment_test.go",
+ "timer_test.go",
+ ],
library = ":tcp",
- deps = ["//pkg/sleep"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 524d5cabf..5e03e7715 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -586,8 +586,14 @@ func (h *handshake) complete() tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ // Check for any ICMP errors notified to us.
if n&notifyError != 0 {
- return h.ep.lastErrorLocked()
+ if err := h.ep.lastErrorLocked(); err != nil {
+ return err
+ }
+ // Flag the handshake failure as aborted if the lastError is
+ // cleared because of a socket layer call.
+ return &tcpip.ErrConnectionAborted{}
}
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -1362,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Reaching this point means that we successfully completed the 3-way
- // handshake with our peer.
- //
+ // handshake with our peer. The current endpoint state could be any state
+ // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a
+ // RST from the peer via the dispatcher fast path, before the loop is
+ // started.
+ if s := e.EndpointState(); !s.connected() {
+ switch s {
+ case StateClose, StateError:
+ // If the endpoint is in CLOSED/ERROR state, sender state has to be
+ // initialized if the endpoint was previously established.
+ if e.snd != nil {
+ break
+ }
+ fallthrough
+ default:
+ panic("endpoint was not established, current state " + s.String())
+ }
+ }
+
// Completing the 3-way handshake is an indication that the route is valid
// and the remote is reachable as the only way we can complete a handshake
// is if our SYN reached the remote and their ACK reached us.
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index d6d68f128..f148d505d 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
// Start connection attempt, it must fail.
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) {
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
_, _, err := c.EP.Accept(&addr)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -611,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.ReadableEvents)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3a7b2d166..50d39cbad 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1280,6 +1280,12 @@ func (e *endpoint) LastError() tcpip.Error {
return e.lastErrorLocked()
}
+// LastErrorLocked reads and clears lastError with e.mu held.
+// Only to be used in tests.
+func (e *endpoint) LastErrorLocked() tcpip.Error {
+ return e.lastErrorLocked()
+}
+
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.LockUser()
@@ -1595,7 +1601,7 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) {
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
-// receive buffer size. This is chosen arbitrairly.
+// receive buffer size. This is chosen arbitrarily.
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index ee2c08cd6..133371455 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -148,6 +148,18 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
}
newWnd = curWnd
}
+
+ // Apply silly-window avoidance when recovering from zero-window situation.
+ // Keep advertising zero receive window up until the new window reaches a
+ // threshold.
+ if r.rcvWnd == 0 && newWnd != 0 {
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ if crossed, above := r.ep.windowCrossedACKThresholdLocked(int(newWnd), int(r.ep.ops.GetReceiveBufferSize())); !crossed && !above {
+ newWnd = 0
+ }
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+ }
+
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index c28641be3..7e5ba6ef7 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -140,6 +140,15 @@ func (s *segment) clone() *segment {
return t
}
+// merge merges data in oth and clears oth.
+func (s *segment) merge(oth *segment) {
+ s.data.Append(oth.data)
+ s.dataMemSize = s.data.Size()
+
+ oth.data = buffer.VectorisedView{}
+ oth.dataMemSize = oth.data.Size()
+}
+
// flagIsSet checks if at least one flag in flags is set in s.flags.
func (s *segment) flagIsSet(flags header.TCPFlags) bool {
return s.flags&flags != 0
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
new file mode 100644
index 000000000..486016fc0
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -0,0 +1,67 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type segmentSizeWants struct {
+ DataSize int
+ SegMemSize int
+}
+
+func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) {
+ t.Helper()
+ got := segmentSizeWants{
+ DataSize: seg.data.Size(),
+ SegMemSize: seg.segMemSize(),
+ }
+ if diff := cmp.Diff(got, want); diff != "" {
+ t.Errorf("%s differs (-want +got):\n%s", name, diff)
+ }
+}
+
+func TestSegmentMerge(t *testing.T) {
+ id := stack.TransportEndpointID{}
+ seg1 := newOutgoingSegment(id, buffer.NewView(10))
+ defer seg1.decRef()
+ seg2 := newOutgoingSegment(id, buffer.NewView(20))
+ defer seg2.decRef()
+
+ checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
+ DataSize: 10,
+ SegMemSize: SegSize + 10,
+ })
+ checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
+ DataSize: 20,
+ SegMemSize: SegSize + 20,
+ })
+
+ seg1.merge(seg2)
+
+ checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
+ DataSize: 30,
+ SegMemSize: SegSize + 30,
+ })
+ checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
+ DataSize: 0,
+ SegMemSize: SegSize,
+ })
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 2b32cb7b2..f43e86677 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -716,15 +716,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// triggering bugs in poorly written DNS
// implementations.
var nextTooBig bool
- for seg.Next() != nil && seg.Next().data.Size() != 0 {
- if seg.data.Size()+seg.Next().data.Size() > available {
+ for nSeg := seg.Next(); nSeg != nil && nSeg.data.Size() != 0; nSeg = seg.Next() {
+ if seg.data.Size()+nSeg.data.Size() > available {
nextTooBig = true
break
}
- seg.data.Append(seg.Next().data)
-
- // Consume the segment that we just merged in.
- s.writeList.Remove(seg.Next())
+ seg.merge(nSeg)
+ s.writeList.Remove(nSeg)
+ nSeg.decRef()
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 3750b0691..9916182e3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
}
for w.N != 0 {
_, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for receive to be notified.
select {
case <-notifyRead:
@@ -130,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) {
{
err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -145,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) {
// and stats updates.
{
err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrAborted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{})
+ if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -159,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) {
}
}
+// Test for ICMP error handling without completing handshake.
+func TestConnectICMPError(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventHUp)
+ defer wq.EventUnregister(&waitEntry)
+
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
+ }
+ }
+
+ syn := c.GetPacket()
+ checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn)))
+
+ wep := ep.(interface {
+ StopWork()
+ ResumeWork()
+ LastErrorLocked() tcpip.Error
+ })
+
+ // Stop the protocol loop, ensure that the ICMP error is processed and
+ // the last ICMP error is read before the loop is resumed. This sanity
+ // tests the handshake completion logic on ICMP errors.
+ wep.StopWork()
+
+ c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU)
+
+ for {
+ if err := wep.LastErrorLocked(); err != nil {
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d)
+ }
+ break
+ }
+ time.Sleep(time.Millisecond)
+ }
+
+ wep.ResumeWork()
+
+ <-notifyCh
+
+ // The stack would have unregistered the endpoint because of the ICMP error.
+ // Expect a RST for any subsequent packets sent to the endpoint.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1,
+ AckNum: c.IRS + 1,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
func TestConnectIncrementActiveConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -202,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{})
+ if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
+ t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -393,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -936,8 +1006,8 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
{
err := c.EP.Connect(connectAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d)
}
}
@@ -1543,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -1604,8 +1674,8 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
// Receive SYN packet.
@@ -2473,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -2545,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3077,8 +3147,8 @@ func TestSetTTL(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -3137,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3191,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -3266,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -3385,8 +3455,8 @@ loop:
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
- t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
break loop
case <-time.After(1 * time.Second):
@@ -3436,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) {
var r bytes.Reader
r.Reset(make([]byte, 10))
_, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
- t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d)
}
}
@@ -4390,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) {
var buf bytes.Buffer
{
_, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
- if _, ok := err.(*tcpip.ErrClosedForReceive); !ok {
- t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{})
+ if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d)
}
}
}
@@ -4435,8 +4505,8 @@ func TestReusePort(t *testing.T) {
}
{
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
}
}
c.EP.Close()
@@ -4724,8 +4794,8 @@ func TestSelfConnect(t *testing.T) {
{
err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
+ t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
}
}
@@ -5428,7 +5498,7 @@ func TestListenBacklogFull(t *testing.T) {
}
lastPortOffset := uint16(0)
- for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ {
+ for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
@@ -5452,7 +5522,7 @@ func TestListenBacklogFull(t *testing.T) {
for i := 0; i < listenBacklog; i++ {
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5469,7 +5539,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5481,7 +5551,7 @@ func TestListenBacklogFull(t *testing.T) {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5794,7 +5864,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Try to accept the connections in the backlog.
newEP, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5865,7 +5935,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
defer c.WQ.EventUnregister(&we)
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -5881,7 +5951,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -6020,7 +6090,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
t.Fatalf("Accept failed: %s", err)
}
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.ReadableEvents)
@@ -6088,7 +6158,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
// Verify that there is only one acceptable connection at this point.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6158,7 +6228,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
// Now check that there is one acceptable connections.
_, _, err = c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6210,7 +6280,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
defer wq.EventUnregister(&we)
aep, _, err := ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6228,8 +6298,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
{
err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok {
- t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{})
+ if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" {
+ t.Errorf("Connect(...) mismatch (-want +got):\n%s", d)
}
}
// Listening endpoint remains in listen state.
@@ -6349,7 +6419,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// window increases to the full available buffer size.
for {
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
break
}
}
@@ -6480,7 +6550,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalCopied := 0
for {
res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
break
}
totalCopied += res.Count
@@ -6672,7 +6742,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6791,7 +6861,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6898,7 +6968,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -6988,7 +7058,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Try to accept the connection.
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7062,7 +7132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7212,7 +7282,7 @@ func TestTCPCloseWithData(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
select {
case <-ch:
@@ -7643,8 +7713,8 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
_, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
+ if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
+ t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
}
// Send data. This should result in an acceptable endpoint.
@@ -7702,8 +7772,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
_, _, err := c.EP.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
+ if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
+ t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
}
// Sleep for a little of the tcpDeferAccept timeout.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 16f8c5212..53efecc5a 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -1214,9 +1214,9 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
if enable {
- c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.HWGSOSupported
} else {
- c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.GSONotSupported
}
}