summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD3
-rw-r--r--pkg/abi/linux/context.go36
-rw-r--r--pkg/abi/linux/ioctl_tun.go3
-rw-r--r--pkg/abi/linux/signal.go344
-rw-r--r--pkg/abi/linux/xattr.go6
-rw-r--r--pkg/buffer/BUILD5
-rw-r--r--pkg/buffer/safemem.go133
-rw-r--r--pkg/buffer/safemem_test.go172
-rw-r--r--pkg/linuxerr/BUILD20
-rw-r--r--pkg/linuxerr/linuxerr.go184
-rw-r--r--pkg/linuxerr/linuxerr_test.go (renamed from pkg/syserror/syserror_test.go)42
-rw-r--r--pkg/marshal/primitive/BUILD2
-rw-r--r--pkg/marshal/primitive/primitive.go25
-rw-r--r--pkg/metric/metric.go69
-rw-r--r--pkg/metric/metric_test.go18
-rw-r--r--pkg/refs/refcounter.go6
-rw-r--r--pkg/ring0/kernel_amd64.go22
-rw-r--r--pkg/safecopy/BUILD1
-rw-r--r--pkg/safecopy/safecopy_unsafe.go14
-rw-r--r--pkg/seccomp/seccomp.go12
-rw-r--r--pkg/sentry/arch/BUILD4
-rw-r--r--pkg/sentry/arch/arch.go14
-rw-r--r--pkg/sentry/arch/signal.go276
-rw-r--r--pkg/sentry/arch/signal_act.go83
-rw-r--r--pkg/sentry/arch/signal_amd64.go26
-rw-r--r--pkg/sentry/arch/signal_arm64.go26
-rw-r--r--pkg/sentry/arch/signal_info.go66
-rw-r--r--pkg/sentry/arch/signal_stack.go68
-rw-r--r--pkg/sentry/arch/stack.go2
-rw-r--r--pkg/sentry/control/proc.go9
-rw-r--r--pkg/sentry/fs/attr.go14
-rw-r--r--pkg/sentry/fs/dev/dev.go8
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go21
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go27
-rw-r--r--pkg/sentry/fs/gofer/BUILD1
-rw-r--r--pkg/sentry/fs/gofer/cache_policy.go2
-rw-r--r--pkg/sentry/fs/gofer/file.go19
-rw-r--r--pkg/sentry/fs/gofer/inode.go18
-rw-r--r--pkg/sentry/fs/gofer/path.go20
-rw-r--r--pkg/sentry/fs/proc/sys.go22
-rw-r--r--pkg/sentry/fs/proc/sys_net.go16
-rw-r--r--pkg/sentry/fs/proc/sys_net_state.go6
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go2
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go15
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go31
-rw-r--r--pkg/sentry/fs/user/user_test.go5
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/base.go27
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go16
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go24
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go23
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go3
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go13
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go16
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go13
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go19
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys_test.go2
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD1
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go7
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go23
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go57
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go13
-rw-r--r--pkg/sentry/fsmetric/fsmetric.go1
-rw-r--r--pkg/sentry/inet/inet.go3
-rw-r--r--pkg/sentry/inet/test_stack.go5
-rw-r--r--pkg/sentry/kernel/BUILD2
-rw-r--r--pkg/sentry/kernel/auth/BUILD2
-rw-r--r--pkg/sentry/kernel/auth/credentials.go2
-rw-r--r--pkg/sentry/kernel/cgroup.go71
-rw-r--r--pkg/sentry/kernel/fasync/BUILD1
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go5
-rw-r--r--pkg/sentry/kernel/fd_table.go8
-rw-r--r--pkg/sentry/kernel/futex/BUILD2
-rw-r--r--pkg/sentry/kernel/kernel.go69
-rw-r--r--pkg/sentry/kernel/pending_signals.go9
-rw-r--r--pkg/sentry/kernel/pending_signals_state.go6
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go8
-rw-r--r--pkg/sentry/kernel/posixtimer.go7
-rw-r--r--pkg/sentry/kernel/ptrace.go23
-rw-r--r--pkg/sentry/kernel/seccomp.go6
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go14
-rw-r--r--pkg/sentry/kernel/sessions.go3
-rw-r--r--pkg/sentry/kernel/signal.go15
-rw-r--r--pkg/sentry/kernel/signal_handlers.go17
-rw-r--r--pkg/sentry/kernel/task.go15
-rw-r--r--pkg/sentry/kernel/task_cgroup.go8
-rw-r--r--pkg/sentry/kernel/task_context.go5
-rw-r--r--pkg/sentry/kernel/task_exec.go3
-rw-r--r--pkg/sentry/kernel/task_exit.go43
-rw-r--r--pkg/sentry/kernel/task_sched.go10
-rw-r--r--pkg/sentry/kernel/task_signals.go124
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/kernel/thread_group.go13
-rw-r--r--pkg/sentry/kernel/time/time.go19
-rw-r--r--pkg/sentry/loader/interpreter.go2
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go18
-rw-r--r--pkg/sentry/platform/kvm/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/address_space.go9
-rw-r--r--pkg/sentry/platform/kvm/address_space_amd64.go (renamed from pkg/tcpip/transport/tcp/rcv_state.go)23
-rw-r--r--pkg/sentry/platform/kvm/address_space_arm64.go (renamed from pkg/tcpip/time.s)14
-rw-r--r--pkg/sentry/platform/kvm/context.go7
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go33
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go32
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go8
-rw-r--r--pkg/sentry/platform/platform.go6
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go3
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go2
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go13
-rw-r--r--pkg/sentry/socket/hostinet/BUILD2
-rw-r--r--pkg/sentry/socket/hostinet/stack.go24
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go16
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go9
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go9
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go22
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go24
-rw-r--r--pkg/sentry/socket/netfilter/targets.go4
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go5
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go5
-rw-r--r--pkg/sentry/socket/netstack/netstack.go184
-rw-r--r--pkg/sentry/socket/netstack/stack.go10
-rw-r--r--pkg/sentry/socket/netstack/tun.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/clone.go46
-rw-r--r--pkg/sentry/strace/linux64_amd64.go2
-rw-r--r--pkg/sentry/strace/linux64_arm64.go2
-rw-r--r--pkg/sentry/strace/mmap.go92
-rw-r--r--pkg/sentry/strace/open.go42
-rw-r--r--pkg/sentry/strace/signal.go4
-rw-r--r--pkg/sentry/strace/strace.go4
-rw-r--r--pkg/sentry/strace/syscalls.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go34
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go41
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go64
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go4
-rw-r--r--pkg/sentry/time/BUILD2
-rw-r--r--pkg/sentry/time/sampler_amd64.go2
-rw-r--r--pkg/sentry/time/sampler_arm64.go2
-rw-r--r--pkg/sentry/vfs/file_description.go8
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go66
-rw-r--r--pkg/sentry/vfs/memxattr/BUILD1
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go33
-rw-r--r--pkg/sentry/vfs/mount.go2
-rw-r--r--pkg/sentry/vfs/opath.go38
-rw-r--r--pkg/sentry/watchdog/watchdog.go10
-rw-r--r--pkg/shim/BUILD14
-rw-r--r--pkg/shim/service.go50
-rw-r--r--pkg/shim/service_test.go121
-rw-r--r--pkg/shim/utils/annotations.go6
-rw-r--r--pkg/shim/utils/utils.go20
-rw-r--r--pkg/shim/utils/volumes.go20
-rw-r--r--pkg/shim/utils/volumes_test.go56
-rw-r--r--pkg/sync/BUILD40
-rw-r--r--pkg/sync/README.md2
-rw-r--r--pkg/sync/atomicptr/BUILD (renamed from pkg/sync/atomicptrtest/BUILD)13
-rw-r--r--pkg/sync/atomicptr/atomicptr_test.go (renamed from pkg/sync/atomicptrtest/atomicptr_test.go)0
-rw-r--r--pkg/sync/atomicptr/generic_atomicptr_unsafe.go (renamed from pkg/sync/generic_atomicptr_unsafe.go)0
-rw-r--r--pkg/sync/atomicptrmap/BUILD (renamed from pkg/sync/atomicptrmaptest/BUILD)25
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap.go (renamed from pkg/sync/atomicptrmaptest/atomicptrmap.go)0
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap_test.go (renamed from pkg/sync/atomicptrmaptest/atomicptrmap_test.go)0
-rw-r--r--pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go (renamed from pkg/sync/generic_atomicptrmap_unsafe.go)0
-rw-r--r--pkg/sync/seqatomic/BUILD (renamed from pkg/sync/seqatomictest/BUILD)17
-rw-r--r--pkg/sync/seqatomic/generic_seqatomic_unsafe.go (renamed from pkg/sync/generic_seqatomic_unsafe.go)0
-rw-r--r--pkg/sync/seqatomic/seqatomic_test.go (renamed from pkg/sync/seqatomictest/seqatomic_test.go)0
-rw-r--r--pkg/syserror/BUILD11
-rw-r--r--pkg/tcpip/BUILD20
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/faketime/faketime.go27
-rw-r--r--pkg/tcpip/faketime/faketime_test.go4
-rw-r--r--pkg/tcpip/header/ipv4.go24
-rw-r--r--pkg/tcpip/header/ndp_options.go6
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go75
-rw-r--r--pkg/tcpip/header/ndp_test.go91
-rw-r--r--pkg/tcpip/header/tcp.go10
-rw-r--r--pkg/tcpip/link/channel/channel.go2
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go14
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go129
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go24
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go22
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go2
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go16
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go2
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go5
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go33
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go164
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go158
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go41
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go37
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go164
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go205
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go83
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go154
-rw-r--r--pkg/tcpip/ports/ports.go10
-rw-r--r--pkg/tcpip/ports/ports_test.go7
-rw-r--r--pkg/tcpip/socketops.go15
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go80
-rw-r--r--pkg/tcpip/stack/iptables.go9
-rw-r--r--pkg/tcpip/stack/iptables_targets.go8
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go558
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go2
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go330
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go42
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go231
-rw-r--r--pkg/tcpip/stack/nic.go93
-rw-r--r--pkg/tcpip/stack/nic_stats.go74
-rw-r--r--pkg/tcpip/stack/nic_test.go32
-rw-r--r--pkg/tcpip/stack/nud.go80
-rw-r--r--pkg/tcpip/stack/nud_test.go37
-rw-r--r--pkg/tcpip/stack/packet_buffer.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go31
-rw-r--r--pkg/tcpip/stack/rand.go4
-rw-r--r--pkg/tcpip/stack/registration.go16
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go232
-rw-r--r--pkg/tcpip/stack/stack_global_state.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go215
-rw-r--r--pkg/tcpip/stack/tcp.go14
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go8
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go73
-rw-r--r--pkg/tcpip/stdclock.go16
-rw-r--r--pkg/tcpip/tcpip.go157
-rw-r--r--pkg/tcpip/tcpip_test.go21
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go248
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go285
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go453
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go11
-rw-r--r--pkg/tcpip/testutil/testutil.go12
-rw-r--r--pkg/tcpip/timer_test.go131
-rw-r--r--pkg/tcpip/transport/icmp/BUILD21
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go41
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go235
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go24
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go22
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/tcp/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go28
-rw-r--r--pkg/tcpip/transport/tcp/connect.go86
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go19
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go26
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go110
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/rack.go14
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go37
-rw-r--r--pkg/tcpip/transport/tcp/segment.go31
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go22
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/snd.go38
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go102
-rw-r--r--pkg/tcpip/transport/tcp/timer.go51
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go7
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go71
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go26
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go115
-rw-r--r--pkg/test/testutil/BUILD1
-rw-r--r--pkg/test/testutil/testutil.go2
-rw-r--r--pkg/usermem/BUILD1
-rw-r--r--pkg/usermem/marshal.go43
283 files changed, 6152 insertions, 4477 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a461bb65e..38288bdb7 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -15,6 +15,7 @@ go_library(
"bpf.go",
"capability.go",
"clone.go",
+ "context.go",
"dev.go",
"elf.go",
"epoll.go",
@@ -77,6 +78,8 @@ go_library(
deps = [
"//pkg/abi",
"//pkg/bits",
+ "//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
],
diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go
new file mode 100644
index 000000000..d2dbba183
--- /dev/null
+++ b/pkg/abi/linux/context.go
@@ -0,0 +1,36 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the linux package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals.
+ CtxSignalNoInfoFunc contextID = iota
+)
+
+// SignalNoInfoFuncFromContext returns a callback function that can be used to send a
+// signal to the given context.
+func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error {
+ if f := ctx.Value(CtxSignalNoInfoFunc); f != nil {
+ return f.(func(Signal) error)
+ }
+ return nil
+}
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
index c59c9c136..ea4fdca0f 100644
--- a/pkg/abi/linux/ioctl_tun.go
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -26,4 +26,7 @@ const (
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_NOFILTER = 0x1000
+
+ // According to linux/if_tun.h "This flag has no real effect"
+ IFF_ONE_QUEUE = 0x2000
)
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index 6ca57ffbb..06a4c6401 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -165,7 +166,7 @@ const (
SIG_IGN = 1
)
-// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h
+// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h.
const (
SA_NOCLDSTOP = 0x00000001
SA_NOCLDWAIT = 0x00000002
@@ -179,21 +180,17 @@ const (
SA_ONESHOT = SA_RESETHAND
)
-// Signal info types.
+// Signal stack flags for sigaltstack(2), from include/uapi/linux/signal.h.
const (
- SI_MASK = 0xffff0000
- SI_KILL = 0 << 16
- SI_TIMER = 1 << 16
- SI_POLL = 2 << 16
- SI_FAULT = 3 << 16
- SI_CHLD = 4 << 16
- SI_RT = 5 << 16
- SI_MESGQ = 6 << 16
- SI_SYS = 7 << 16
+ SS_ONSTACK = 1
+ SS_DISABLE = 2
)
// SIGPOLL si_codes.
const (
+ // SI_POLL is defined as __SI_POLL in Linux 2.6.
+ SI_POLL = 2 << 16
+
// POLL_IN indicates that data input available.
POLL_IN = SI_POLL | 1
@@ -213,6 +210,75 @@ const (
POLL_HUP = SI_POLL | 6
)
+// Possible values for si_code.
+const (
+ // SI_USER is sent by kill, sigsend, raise.
+ SI_USER = 0
+
+ // SI_KERNEL is sent by the kernel from somewhere.
+ SI_KERNEL = 0x80
+
+ // SI_QUEUE is sent by sigqueue.
+ SI_QUEUE = -1
+
+ // SI_TIMER is sent by timer expiration.
+ SI_TIMER = -2
+
+ // SI_MESGQ is sent by real-time message queue state change.
+ SI_MESGQ = -3
+
+ // SI_ASYNCIO is sent by AIO completion.
+ SI_ASYNCIO = -4
+
+ // SI_SIGIO is sent by queued SIGIO.
+ SI_SIGIO = -5
+
+ // SI_TKILL is sent by tkill system call.
+ SI_TKILL = -6
+
+ // SI_DETHREAD is sent by execve() killing subsidiary threads.
+ SI_DETHREAD = -7
+
+ // SI_ASYNCNL is sent by glibc async name lookup completion.
+ SI_ASYNCNL = -60
+)
+
+// CLD_* codes are only meaningful for SIGCHLD.
+const (
+ // CLD_EXITED indicates that a task exited.
+ CLD_EXITED = 1
+
+ // CLD_KILLED indicates that a task was killed by a signal.
+ CLD_KILLED = 2
+
+ // CLD_DUMPED indicates that a task was killed by a signal and then dumped
+ // core.
+ CLD_DUMPED = 3
+
+ // CLD_TRAPPED indicates that a task was stopped by ptrace.
+ CLD_TRAPPED = 4
+
+ // CLD_STOPPED indicates that a thread group completed a group stop.
+ CLD_STOPPED = 5
+
+ // CLD_CONTINUED indicates that a group-stopped thread group was continued.
+ CLD_CONTINUED = 6
+)
+
+// SYS_* codes are only meaningful for SIGSYS.
+const (
+ // SYS_SECCOMP indicates that a signal originates from seccomp.
+ SYS_SECCOMP = 1
+)
+
+// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
+const (
+ SIGEV_SIGNAL = 0
+ SIGEV_NONE = 1
+ SIGEV_THREAD = 2
+ SIGEV_THREAD_ID = 4
+)
+
// Sigevent represents struct sigevent.
//
// +marshal
@@ -227,10 +293,252 @@ type Sigevent struct {
UnRemainder [44]byte
}
-// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
-const (
- SIGEV_SIGNAL = 0
- SIGEV_NONE = 1
- SIGEV_THREAD = 2
- SIGEV_THREAD_ID = 4
-)
+// SigAction represents struct sigaction.
+//
+// +marshal
+// +stateify savable
+type SigAction struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64
+ Mask SignalSet
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +marshal
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// Contains checks if the stack pointer is within this stack.
+func (s *SignalStack) Contains(sp hostarch.Addr) bool {
+ return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size)
+}
+
+// Top returns the stack's top address.
+func (s *SignalStack) Top() hostarch.Addr {
+ return hostarch.Addr(s.Addr + s.Size)
+}
+
+// IsEnabled returns true iff this signal stack is marked as enabled.
+func (s *SignalStack) IsEnabled() bool {
+ return s.Flags&SS_DISABLE == 0
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in the Linux kernel (linux/include/uapi/asm-generic/siginfo.h).
+//
+// +marshal
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() TimerID {
+ return TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val TimerID) {
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return hostarch.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return hostarch.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index 8ef837f27..1fa7a4f4f 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -23,6 +23,12 @@ const (
XATTR_CREATE = 1
XATTR_REPLACE = 2
+ XATTR_SECURITY_PREFIX = "security."
+ XATTR_SECURITY_PREFIX_LEN = len(XATTR_SECURITY_PREFIX)
+
+ XATTR_SYSTEM_PREFIX = "system."
+ XATTR_SYSTEM_PREFIX_LEN = len(XATTR_SYSTEM_PREFIX)
+
XATTR_TRUSTED_PREFIX = "trusted."
XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX)
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
index 2a2e3d1aa..19cd28a32 100644
--- a/pkg/buffer/BUILD
+++ b/pkg/buffer/BUILD
@@ -21,7 +21,6 @@ go_library(
"buffer.go",
"buffer_list.go",
"pool.go",
- "safemem.go",
"view.go",
"view_unsafe.go",
],
@@ -29,8 +28,6 @@ go_library(
deps = [
"//pkg/context",
"//pkg/log",
- "//pkg/safemem",
- "//pkg/usermem",
],
)
@@ -40,12 +37,10 @@ go_test(
srcs = [
"buffer_test.go",
"pool_test.go",
- "safemem_test.go",
"view_test.go",
],
library = ":buffer",
deps = [
- "//pkg/safemem",
"//pkg/state",
],
)
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
deleted file mode 100644
index 8b42575b4..000000000
--- a/pkg/buffer/safemem.go
+++ /dev/null
@@ -1,133 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package buffer
-
-import (
- "gvisor.dev/gvisor/pkg/safemem"
-)
-
-// WriteBlock returns this buffer as a write Block.
-func (b *buffer) WriteBlock() safemem.Block {
- return safemem.BlockFromSafeSlice(b.WriteSlice())
-}
-
-// ReadBlock returns this buffer as a read Block.
-func (b *buffer) ReadBlock() safemem.Block {
- return safemem.BlockFromSafeSlice(b.ReadSlice())
-}
-
-// WriteFromSafememReader writes up to count bytes from r to v and advances the
-// write index by the number of bytes written. It calls r.ReadToBlocks() at
-// most once.
-func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) {
- if count == 0 {
- return 0, nil
- }
-
- var (
- dst safemem.BlockSeq
- blocks []safemem.Block
- )
-
- // Need at least one buffer.
- firstBuf := v.data.Back()
- if firstBuf == nil {
- firstBuf = v.pool.get()
- v.data.PushBack(firstBuf)
- }
-
- // Does the last block have sufficient capacity alone?
- if l := uint64(firstBuf.WriteSize()); l >= count {
- dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count))
- } else {
- // Append blocks until sufficient.
- count -= l
- blocks = append(blocks, firstBuf.WriteBlock())
- for count > 0 {
- emptyBuf := v.pool.get()
- v.data.PushBack(emptyBuf)
- block := emptyBuf.WriteBlock().TakeFirst64(count)
- count -= uint64(block.Len())
- blocks = append(blocks, block)
- }
- dst = safemem.BlockSeqFromSlice(blocks)
- }
-
- // Perform I/O.
- n, err := r.ReadToBlocks(dst)
- v.size += int64(n)
-
- // Update all indices.
- for left := n; left > 0; firstBuf = firstBuf.Next() {
- if l := firstBuf.WriteSize(); left >= uint64(l) {
- firstBuf.WriteMove(l) // Whole block.
- left -= uint64(l)
- } else {
- firstBuf.WriteMove(int(left)) // Partial block.
- left = 0
- }
- }
-
- return n, err
-}
-
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the
-// write index by the number of bytes written.
-func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes())
-}
-
-// ReadToSafememWriter reads up to count bytes from v to w. It does not advance
-// the read index. It calls w.WriteFromBlocks() at most once.
-func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) {
- if count == 0 {
- return 0, nil
- }
-
- var (
- src safemem.BlockSeq
- blocks []safemem.Block
- )
-
- firstBuf := v.data.Front()
- if firstBuf == nil {
- return 0, nil // No EOF.
- }
-
- // Is all the data in a single block?
- if l := uint64(firstBuf.ReadSize()); l >= count {
- src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count))
- } else {
- // Build a list of all the buffers.
- count -= l
- blocks = append(blocks, firstBuf.ReadBlock())
- for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() {
- block := buf.ReadBlock().TakeFirst64(count)
- count -= uint64(block.Len())
- blocks = append(blocks, block)
- }
- src = safemem.BlockSeqFromSlice(blocks)
- }
-
- // Perform I/O. As documented, we don't advance the read index.
- return w.WriteFromBlocks(src)
-}
-
-// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the
-// read index by the number of bytes read, such that it's only safe to call if
-// the caller guarantees that ReadToBlocks will only be called once.
-func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes())
-}
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
deleted file mode 100644
index 721cc5934..000000000
--- a/pkg/buffer/safemem_test.go
+++ /dev/null
@@ -1,172 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package buffer
-
-import (
- "bytes"
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/safemem"
-)
-
-func TestSafemem(t *testing.T) {
- const bufferSize = defaultBufferSize
-
- testCases := []struct {
- name string
- input string
- output string
- readLen int
- op func(*View)
- }{
- // Basic coverage.
- {
- name: "short",
- input: "010",
- output: "010",
- },
- {
- name: "long",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: "0" + strings.Repeat("1", bufferSize) + "0",
- },
- {
- name: "short-read",
- input: "0",
- readLen: 100, // > size.
- output: "0",
- },
- {
- name: "zero-read",
- input: "0",
- output: "",
- },
- {
- name: "read-empty",
- input: "",
- readLen: 1, // > size.
- output: "",
- },
-
- // Ensure offsets work.
- {
- name: "offsets-short",
- input: "012",
- output: "2",
- op: func(v *View) {
- v.TrimFront(2)
- },
- },
- {
- name: "offsets-long0",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: strings.Repeat("1", bufferSize) + "0",
- op: func(v *View) {
- v.TrimFront(1)
- },
- },
- {
- name: "offsets-long1",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: strings.Repeat("1", bufferSize-1) + "0",
- op: func(v *View) {
- v.TrimFront(2)
- },
- },
- {
- name: "offsets-long2",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: "10",
- op: func(v *View) {
- v.TrimFront(bufferSize)
- },
- },
-
- // Ensure truncation works.
- {
- name: "truncate-short",
- input: "012",
- output: "01",
- op: func(v *View) {
- v.Truncate(2)
- },
- },
- {
- name: "truncate-long0",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: "0" + strings.Repeat("1", bufferSize),
- op: func(v *View) {
- v.Truncate(bufferSize + 1)
- },
- },
- {
- name: "truncate-long1",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: "0" + strings.Repeat("1", bufferSize-1),
- op: func(v *View) {
- v.Truncate(bufferSize)
- },
- },
- {
- name: "truncate-long2",
- input: "0" + strings.Repeat("1", bufferSize) + "0",
- output: "01",
- op: func(v *View) {
- v.Truncate(2)
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Construct the new view.
- var view View
- bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input)))
- n, err := view.WriteFromBlocks(bs)
- if err != nil {
- t.Errorf("expected err nil, got %v", err)
- }
- if n != uint64(len(tc.input)) {
- t.Errorf("expected %d bytes, got %d", len(tc.input), n)
- }
-
- // Run the operation.
- if tc.op != nil {
- tc.op(&view)
- }
-
- // Read and validate.
- readLen := tc.readLen
- if readLen == 0 {
- readLen = len(tc.output) // Default.
- }
- out := make([]byte, readLen)
- bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out))
- n, err = view.ReadToBlocks(bs)
- if err != nil {
- t.Errorf("expected nil, got %v", err)
- }
- if n != uint64(len(tc.output)) {
- t.Errorf("expected %d bytes, got %d", len(tc.output), n)
- }
-
- // Ensure the contents are correct.
- if !bytes.Equal(out[:n], []byte(tc.output[:n])) {
- t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out))
- }
- })
- }
-}
diff --git a/pkg/linuxerr/BUILD b/pkg/linuxerr/BUILD
new file mode 100644
index 000000000..c5abbd34f
--- /dev/null
+++ b/pkg/linuxerr/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "linuxerr",
+ srcs = ["linuxerr.go"],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+ name = "linuxerr_test",
+ srcs = ["linuxerr_test.go"],
+ deps = [
+ ":linuxerr",
+ "//pkg/syserror",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/linuxerr/linuxerr.go b/pkg/linuxerr/linuxerr.go
new file mode 100644
index 000000000..f45caaadf
--- /dev/null
+++ b/pkg/linuxerr/linuxerr.go
@@ -0,0 +1,184 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linuxerr contains syscall error codes exported as error interface
+// pointers. This allows for fast comparison and return operations comparable
+// to unix.Errno constants.
+package linuxerr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// Error represents a syscall errno with a descriptive message.
+type Error struct {
+ errno linux.Errno
+ message string
+}
+
+func new(err linux.Errno, message string) *Error {
+ return &Error{
+ errno: err,
+ message: message,
+ }
+}
+
+// Error implements error.Error.
+func (e *Error) Error() string { return e.message }
+
+// Errno returns the underlying linux.Errno value.
+func (e *Error) Errno() linux.Errno { return e.errno }
+
+// The following variables have the same meaning as their errno equivalents.
+
+// Errno values from include/uapi/asm-generic/errno-base.h.
+var (
+ EPERM = new(linux.EPERM, "operation not permitted")
+ ENOENT = new(linux.ENOENT, "no such file or directory")
+ ESRCH = new(linux.ESRCH, "no such process")
+ EINTR = new(linux.EINTR, "interrupted system call")
+ EIO = new(linux.EIO, "I/O error")
+ ENXIO = new(linux.ENXIO, "no such device or address")
+ E2BIG = new(linux.E2BIG, "argument list too long")
+ ENOEXEC = new(linux.ENOEXEC, "exec format error")
+ EBADF = new(linux.EBADF, "bad file number")
+ ECHILD = new(linux.ECHILD, "no child processes")
+ EAGAIN = new(linux.EAGAIN, "try again")
+ ENOMEM = new(linux.ENOMEM, "out of memory")
+ EACCES = new(linux.EACCES, "permission denied")
+ EFAULT = new(linux.EFAULT, "bad address")
+ ENOTBLK = new(linux.ENOTBLK, "block device required")
+ EBUSY = new(linux.EBUSY, "device or resource busy")
+ EEXIST = new(linux.EEXIST, "file exists")
+ EXDEV = new(linux.EXDEV, "cross-device link")
+ ENODEV = new(linux.ENODEV, "no such device")
+ ENOTDIR = new(linux.ENOTDIR, "not a directory")
+ EISDIR = new(linux.EISDIR, "is a directory")
+ EINVAL = new(linux.EINVAL, "invalid argument")
+ ENFILE = new(linux.ENFILE, "file table overflow")
+ EMFILE = new(linux.EMFILE, "too many open files")
+ ENOTTY = new(linux.ENOTTY, "not a typewriter")
+ ETXTBSY = new(linux.ETXTBSY, "text file busy")
+ EFBIG = new(linux.EFBIG, "file too large")
+ ENOSPC = new(linux.ENOSPC, "no space left on device")
+ ESPIPE = new(linux.ESPIPE, "illegal seek")
+ EROFS = new(linux.EROFS, "read-only file system")
+ EMLINK = new(linux.EMLINK, "too many links")
+ EPIPE = new(linux.EPIPE, "broken pipe")
+ EDOM = new(linux.EDOM, "math argument out of domain of func")
+ ERANGE = new(linux.ERANGE, "math result not representable")
+)
+
+// Errno values from include/uapi/asm-generic/errno.h.
+var (
+ EDEADLK = new(linux.EDEADLK, "resource deadlock would occur")
+ ENAMETOOLONG = new(linux.ENAMETOOLONG, "file name too long")
+ ENOLCK = new(linux.ENOLCK, "no record locks available")
+ ENOSYS = new(linux.ENOSYS, "invalid system call number")
+ ENOTEMPTY = new(linux.ENOTEMPTY, "directory not empty")
+ ELOOP = new(linux.ELOOP, "too many symbolic links encountered")
+ EWOULDBLOCK = new(linux.EWOULDBLOCK, "operation would block")
+ ENOMSG = new(linux.ENOMSG, "no message of desired type")
+ EIDRM = new(linux.EIDRM, "identifier removed")
+ ECHRNG = new(linux.ECHRNG, "channel number out of range")
+ EL2NSYNC = new(linux.EL2NSYNC, "level 2 not synchronized")
+ EL3HLT = new(linux.EL3HLT, "level 3 halted")
+ EL3RST = new(linux.EL3RST, "level 3 reset")
+ ELNRNG = new(linux.ELNRNG, "link number out of range")
+ EUNATCH = new(linux.EUNATCH, "protocol driver not attached")
+ ENOCSI = new(linux.ENOCSI, "no CSI structure available")
+ EL2HLT = new(linux.EL2HLT, "level 2 halted")
+ EBADE = new(linux.EBADE, "invalid exchange")
+ EBADR = new(linux.EBADR, "invalid request descriptor")
+ EXFULL = new(linux.EXFULL, "exchange full")
+ ENOANO = new(linux.ENOANO, "no anode")
+ EBADRQC = new(linux.EBADRQC, "invalid request code")
+ EBADSLT = new(linux.EBADSLT, "invalid slot")
+ EDEADLOCK = new(linux.EDEADLOCK, EDEADLK.message)
+ EBFONT = new(linux.EBFONT, "bad font file format")
+ ENOSTR = new(linux.ENOSTR, "device not a stream")
+ ENODATA = new(linux.ENODATA, "no data available")
+ ETIME = new(linux.ETIME, "timer expired")
+ ENOSR = new(linux.ENOSR, "out of streams resources")
+ ENONET = new(linux.ENONET, "machine is not on the network")
+ ENOPKG = new(linux.ENOPKG, "package not installed")
+ EREMOTE = new(linux.EREMOTE, "object is remote")
+ ENOLINK = new(linux.ENOLINK, "link has been severed")
+ EADV = new(linux.EADV, "advertise error")
+ ESRMNT = new(linux.ESRMNT, "srmount error")
+ ECOMM = new(linux.ECOMM, "communication error on send")
+ EPROTO = new(linux.EPROTO, "protocol error")
+ EMULTIHOP = new(linux.EMULTIHOP, "multihop attempted")
+ EDOTDOT = new(linux.EDOTDOT, "RFS specific error")
+ EBADMSG = new(linux.EBADMSG, "not a data message")
+ EOVERFLOW = new(linux.EOVERFLOW, "value too large for defined data type")
+ ENOTUNIQ = new(linux.ENOTUNIQ, "name not unique on network")
+ EBADFD = new(linux.EBADFD, "file descriptor in bad state")
+ EREMCHG = new(linux.EREMCHG, "remote address changed")
+ ELIBACC = new(linux.ELIBACC, "can not access a needed shared library")
+ ELIBBAD = new(linux.ELIBBAD, "accessing a corrupted shared library")
+ ELIBSCN = new(linux.ELIBSCN, ".lib section in a.out corrupted")
+ ELIBMAX = new(linux.ELIBMAX, "attempting to link in too many shared libraries")
+ ELIBEXEC = new(linux.ELIBEXEC, "cannot exec a shared library directly")
+ EILSEQ = new(linux.EILSEQ, "illegal byte sequence")
+ ERESTART = new(linux.ERESTART, "interrupted system call should be restarted")
+ ESTRPIPE = new(linux.ESTRPIPE, "streams pipe error")
+ EUSERS = new(linux.EUSERS, "too many users")
+ ENOTSOCK = new(linux.ENOTSOCK, "socket operation on non-socket")
+ EDESTADDRREQ = new(linux.EDESTADDRREQ, "destination address required")
+ EMSGSIZE = new(linux.EMSGSIZE, "message too long")
+ EPROTOTYPE = new(linux.EPROTOTYPE, "protocol wrong type for socket")
+ ENOPROTOOPT = new(linux.ENOPROTOOPT, "protocol not available")
+ EPROTONOSUPPORT = new(linux.EPROTONOSUPPORT, "protocol not supported")
+ ESOCKTNOSUPPORT = new(linux.ESOCKTNOSUPPORT, "socket type not supported")
+ EOPNOTSUPP = new(linux.EOPNOTSUPP, "operation not supported on transport endpoint")
+ EPFNOSUPPORT = new(linux.EPFNOSUPPORT, "protocol family not supported")
+ EAFNOSUPPORT = new(linux.EAFNOSUPPORT, "address family not supported by protocol")
+ EADDRINUSE = new(linux.EADDRINUSE, "address already in use")
+ EADDRNOTAVAIL = new(linux.EADDRNOTAVAIL, "cannot assign requested address")
+ ENETDOWN = new(linux.ENETDOWN, "network is down")
+ ENETUNREACH = new(linux.ENETUNREACH, "network is unreachable")
+ ENETRESET = new(linux.ENETRESET, "network dropped connection because of reset")
+ ECONNABORTED = new(linux.ECONNABORTED, "software caused connection abort")
+ ECONNRESET = new(linux.ECONNRESET, "connection reset by peer")
+ ENOBUFS = new(linux.ENOBUFS, "no buffer space available")
+ EISCONN = new(linux.EISCONN, "transport endpoint is already connected")
+ ENOTCONN = new(linux.ENOTCONN, "transport endpoint is not connected")
+ ESHUTDOWN = new(linux.ESHUTDOWN, "cannot send after transport endpoint shutdown")
+ ETOOMANYREFS = new(linux.ETOOMANYREFS, "too many references: cannot splice")
+ ETIMEDOUT = new(linux.ETIMEDOUT, "connection timed out")
+ ECONNREFUSED = new(linux.ECONNREFUSED, "connection refused")
+ EHOSTDOWN = new(linux.EHOSTDOWN, "host is down")
+ EHOSTUNREACH = new(linux.EHOSTUNREACH, "no route to host")
+ EALREADY = new(linux.EALREADY, "operation already in progress")
+ EINPROGRESS = new(linux.EINPROGRESS, "operation now in progress")
+ ESTALE = new(linux.ESTALE, "stale file handle")
+ EUCLEAN = new(linux.EUCLEAN, "structure needs cleaning")
+ ENOTNAM = new(linux.ENOTNAM, "not a XENIX named type file")
+ ENAVAIL = new(linux.ENAVAIL, "no XENIX semaphores available")
+ EISNAM = new(linux.EISNAM, "is a named type file")
+ EREMOTEIO = new(linux.EREMOTEIO, "remote I/O error")
+ EDQUOT = new(linux.EDQUOT, "quota exceeded")
+ ENOMEDIUM = new(linux.ENOMEDIUM, "no medium found")
+ EMEDIUMTYPE = new(linux.EMEDIUMTYPE, "wrong medium type")
+ ECANCELED = new(linux.ECANCELED, "operation Canceled")
+ ENOKEY = new(linux.ENOKEY, "required key not available")
+ EKEYEXPIRED = new(linux.EKEYEXPIRED, "key has expired")
+ EKEYREVOKED = new(linux.EKEYREVOKED, "key has been revoked")
+ EKEYREJECTED = new(linux.EKEYREJECTED, "key was rejected by service")
+ EOWNERDEAD = new(linux.EOWNERDEAD, "owner died")
+ ENOTRECOVERABLE = new(linux.ENOTRECOVERABLE, "state not recoverable")
+ ERFKILL = new(linux.ERFKILL, "operation not possible due to RF-kill")
+ EHWPOISON = new(linux.EHWPOISON, "memory page has hardware error")
+)
diff --git a/pkg/syserror/syserror_test.go b/pkg/linuxerr/linuxerr_test.go
index c141e5f6e..d34937e93 100644
--- a/pkg/syserror/syserror_test.go
+++ b/pkg/linuxerr/linuxerr_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -30,7 +31,13 @@ func BenchmarkAssignErrno(b *testing.B) {
}
}
-func BenchmarkAssignError(b *testing.B) {
+func BenchmarkLinuxerrAssignError(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ globalError = linuxerr.EINVAL
+ }
+}
+
+func BenchmarkAssignSyserrorError(b *testing.B) {
for i := b.N; i > 0; i-- {
globalError = syserror.EINVAL
}
@@ -46,7 +53,17 @@ func BenchmarkCompareErrno(b *testing.B) {
}
}
-func BenchmarkCompareError(b *testing.B) {
+func BenchmarkCompareLinuxerrError(b *testing.B) {
+ globalError = linuxerr.E2BIG
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == linuxerr.EINVAL {
+ j++
+ }
+ }
+}
+
+func BenchmarkCompareSyserrorError(b *testing.B) {
globalError = syserror.EAGAIN
j := 0
for i := b.N; i > 0; i-- {
@@ -62,7 +79,7 @@ func BenchmarkSwitchErrno(b *testing.B) {
for i := b.N; i > 0; i-- {
switch globalError {
case unix.EINVAL:
- j += 1
+ j++
case unix.EINTR:
j += 2
case unix.EAGAIN:
@@ -71,13 +88,28 @@ func BenchmarkSwitchErrno(b *testing.B) {
}
}
-func BenchmarkSwitchError(b *testing.B) {
+func BenchmarkSwitchLinuxerrError(b *testing.B) {
+ globalError = linuxerr.EPERM
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case linuxerr.EINVAL:
+ j++
+ case linuxerr.EINTR:
+ j += 2
+ case linuxerr.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+func BenchmarkSwitchSyserrorError(b *testing.B) {
globalError = syserror.EPERM
j := 0
for i := b.N; i > 0; i-- {
switch globalError {
case syserror.EINVAL:
- j += 1
+ j++
case syserror.EINTR:
j += 2
case syserror.EAGAIN:
diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
index 190b57c29..6e5ce136d 100644
--- a/pkg/marshal/primitive/BUILD
+++ b/pkg/marshal/primitive/BUILD
@@ -12,9 +12,7 @@ go_library(
"//:sandbox",
],
deps = [
- "//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
- "//pkg/usermem",
],
)
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index 6f38992b7..1c49cf082 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -19,10 +19,8 @@ package primitive
import (
"io"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Int8 is a marshal.Marshallable implementation for int8.
@@ -400,26 +398,3 @@ func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int,
srcP := ByteSlice(src)
return srcP.CopyOut(cc, addr)
}
-
-// IOCopyContext wraps an object implementing hostarch.IO to implement
-// marshal.CopyContext.
-type IOCopyContext struct {
- Ctx context.Context
- IO usermem.IO
- Opts usermem.IOOpts
-}
-
-// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
-func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
- return make([]byte, size)
-}
-
-// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
-func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
- return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
-}
-
-// CopyInBytes implements marshal.CopyContext.CopyInBytes.
-func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
- return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
-}
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index e822fe77d..4829ae7ce 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -37,9 +37,21 @@ var (
ErrInitializationDone = errors.New("metric cannot be created after initialization is complete")
// WeirdnessMetric is a metric with fields created to track the number
- // of weird occurrences such as time fallback, partial_result and
- // vsyscall count.
- WeirdnessMetric *Uint64Metric
+ // of weird occurrences such as time fallback, partial_result, vsyscall
+ // count, watchdog startup timeouts and stuck tasks.
+ WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.",
+ Field{
+ name: "weirdness_type",
+ allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"},
+ })
+
+ // SuspiciousOperationsMetric is a metric with fields created to detect
+ // operations such as opening an executable file to write from a gofer.
+ SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.",
+ Field{
+ name: "operation_type",
+ allowedValues: []string{"opened_write_execute_file"},
+ })
)
// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
@@ -77,17 +89,21 @@ var (
// Precondition:
// * All metrics are registered.
// * Initialize/Disable has not been called.
-func Initialize() {
+func Initialize() error {
if initialized {
- panic("Initialize/Disable called more than once")
+ return errors.New("metric.Initialize called after metric.Initialize or metric.Disable")
}
- initialized = true
m := pb.MetricRegistration{}
for _, v := range allMetrics.m {
m.Metrics = append(m.Metrics, v.metadata)
}
- eventchannel.Emit(&m)
+ if err := eventchannel.Emit(&m); err != nil {
+ return fmt.Errorf("unable to emit metric initialize event: %w", err)
+ }
+
+ initialized = true
+ return nil
}
// Disable sends an empty metric registration event over the event channel,
@@ -96,16 +112,18 @@ func Initialize() {
// Precondition:
// * All metrics are registered.
// * Initialize/Disable has not been called.
-func Disable() {
+func Disable() error {
if initialized {
- panic("Initialize/Disable called more than once")
+ return errors.New("metric.Disable called after metric.Initialize or metric.Disable")
}
- initialized = true
m := pb.MetricRegistration{}
if err := eventchannel.Emit(&m); err != nil {
- panic("unable to emit metric disable event: " + err.Error())
+ return fmt.Errorf("unable to emit metric disable event: %w", err)
}
+
+ initialized = true
+ return nil
}
type customUint64Metric struct {
@@ -158,8 +176,8 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
}
// Metrics can exist without fields.
- if len(fields) > 1 {
- panic("Sentry metrics support at most one field")
+ if l := len(fields); l > 1 {
+ return fmt.Errorf("%d fields provided, must be <= 1", l)
}
for _, field := range fields {
@@ -175,7 +193,7 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
// without fields and panics if it returns an error.
func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) {
if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil {
- panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to register metric %q: %s", name, err))
}
}
@@ -202,7 +220,7 @@ func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, desc
func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric {
m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...)
if err != nil {
- panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to create metric %q: %s", name, err))
}
return m
}
@@ -212,7 +230,7 @@ func MustCreateNewUint64Metric(name string, sync bool, description string, field
func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric {
m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description)
if err != nil {
- panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to create metric %q: %s", name, err))
}
return m
}
@@ -347,7 +365,7 @@ func EmitMetricUpdate() {
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
- Value: &pb.MetricValue_Uint64Value{t},
+ Value: &pb.MetricValue_Uint64Value{Uint64Value: t},
})
case map[string]uint64:
for fieldValue, metricValue := range t {
@@ -362,7 +380,7 @@ func EmitMetricUpdate() {
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
FieldValues: []string{fieldValue},
- Value: &pb.MetricValue_Uint64Value{metricValue},
+ Value: &pb.MetricValue_Uint64Value{Uint64Value: metricValue},
})
}
}
@@ -383,18 +401,7 @@ func EmitMetricUpdate() {
}
}
- eventchannel.Emit(&m)
-}
-
-// CreateSentryMetrics creates the sentry metrics during kernel initialization.
-func CreateSentryMetrics() {
- if WeirdnessMetric != nil {
- return
+ if err := eventchannel.Emit(&m); err != nil {
+ log.Warningf("Unable to emit metrics: %s", err)
}
-
- WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result and vsyscalls invoked in the sandbox",
- Field{
- name: "weirdness_type",
- allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count"},
- })
}
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index c71dfd460..1b4a9e73a 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -48,6 +48,8 @@ func (s *sliceEmitter) Reset() {
var emitter sliceEmitter
func init() {
+ reset()
+
eventchannel.AddEmitter(&emitter)
}
@@ -77,7 +79,9 @@ func TestInitialize(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
if len(emitter) != 1 {
t.Fatalf("Initialize emitted %d events want 1", len(emitter))
@@ -149,7 +153,9 @@ func TestDisable(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Disable()
+ if err := Disable(); err != nil {
+ t.Fatalf("Disable(): %s", err)
+ }
if len(emitter) != 1 {
t.Fatalf("Initialize emitted %d events want 1", len(emitter))
@@ -178,7 +184,9 @@ func TestEmitMetricUpdate(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
// Don't care about the registration metrics.
emitter.Reset()
@@ -270,7 +278,9 @@ func TestEmitMetricUpdateWithFields(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
// Don't care about the registration metrics.
emitter.Reset()
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 4aecb8007..1bbcae045 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -261,8 +261,8 @@ func (l *LeakMode) Get() interface{} {
}
// String implements flag.Value.
-func (l *LeakMode) String() string {
- switch *l {
+func (l LeakMode) String() string {
+ switch l {
case UninitializedLeakChecking:
return "uninitialized"
case NoLeakChecking:
@@ -272,7 +272,7 @@ func (l *LeakMode) String() string {
case LeaksLogTraces:
return "log-traces"
}
- panic(fmt.Sprintf("invalid ref leak mode %d", *l))
+ panic(fmt.Sprintf("invalid ref leak mode %d", l))
}
// leakMode stores the current mode for the reference leak checker.
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 41dfd0bf9..f63af8b76 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -254,6 +254,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
return
}
+var sentryXCR0 = xgetbv(0)
+
// start is the CPU entrypoint.
//
// This is called from the Start asm stub (see entry_amd64.go); on return the
@@ -265,24 +267,10 @@ func start(c *CPU) {
WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
- // Initialize floating point.
- //
- // Note that on skylake, the valid XCR0 mask reported seems to be 0xff.
- // This breaks down as:
- //
- // bit0 - x87
- // bit1 - SSE
- // bit2 - AVX
- // bit3-4 - MPX
- // bit5-7 - AVX512
- //
- // For some reason, enabled MPX & AVX512 on platforms that report them
- // seems to be cause a general protection fault. (Maybe there are some
- // virtualization issues and these aren't exported to the guest cpuid.)
- // This needs further investigation, but we can limit the floating
- // point operations to x87, SSE & AVX for now.
fninit()
- xsetbv(0, validXCR0Mask&0x7)
+ // Need to sync XCR0 with the host, because xsave and xrstor can be
+ // called from different contexts.
+ xsetbv(0, sentryXCR0)
// Set the syscall target.
wrmsr(_MSR_LSTAR, kernelFunc(sysenter))
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
index b77c40279..db5787302 100644
--- a/pkg/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -18,6 +18,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
"//pkg/syserror",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index efbc2ddc1..2365b2c0d 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -20,6 +20,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// maxRegisterSize is the maximum register size used in memcpy and memclr. It
@@ -342,12 +343,7 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
// handler however, and if this is function is being used externally then the
// same courtesy is expected.
func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
- var sa struct {
- handler uintptr
- flags uint64
- restorer uintptr
- mask uint64
- }
+ var sa linux.SigAction
const maskLen = 8
// Get the existing signal handler information, and save the current
@@ -358,14 +354,14 @@ func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) e
}
// Fail if there isn't a previous handler.
- if sa.handler == 0 {
+ if sa.Handler == 0 {
return fmt.Errorf("previous handler for signal %x isn't set", sig)
}
- *previous = sa.handler
+ *previous = uintptr(sa.Handler)
// Install our own handler.
- sa.handler = handler
+ sa.Handler = uint64(handler)
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
return e
}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index daea51c4d..8ffa1db37 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -36,14 +36,10 @@ const (
// Install generates BPF code based on the set of syscalls provided. It only
// allows syscalls that conform to the specification. Syscalls that violate the
-// specification will trigger RET_KILL_PROCESS, except for the cases below.
-//
-// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the
-// following cases:
-// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the
-// offending thread and often keeps the sentry hanging.
-// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is
-// much easier to debug then RET_KILL_PROCESS which can't be caught.
+// specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not
+// supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not
+// used because it only kills the offending thread and often keeps the sentry
+// hanging.
//
// Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored,
// making it possible for the process to continue running after a violation.
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index c9c52530d..61dacd2fb 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -14,12 +14,8 @@ go_library(
"arch_x86.go",
"arch_x86_impl.go",
"auxv.go",
- "signal.go",
- "signal_act.go",
"signal_amd64.go",
"signal_arm64.go",
- "signal_info.go",
- "signal_stack.go",
"stack.go",
"stack_unsafe.go",
"syscalls_amd64.go",
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 290863ee6..c9393b091 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -134,21 +134,13 @@ type Context interface {
// RegisterMap returns a map of all registers.
RegisterMap() (map[string]uintptr, error)
- // NewSignalAct returns a new object that is equivalent to struct sigaction
- // in the guest architecture.
- NewSignalAct() NativeSignalAct
-
- // NewSignalStack returns a new object that is equivalent to stack_t in the
- // guest architecture.
- NewSignalStack() NativeSignalStack
-
// SignalSetup modifies the context in preparation for handling the
// given signal.
//
// st is the stack where the signal handler frame should be
// constructed.
//
- // act is the SignalAct that specifies how this signal is being
+ // act is the SigAction that specifies how this signal is being
// handled.
//
// info is the SignalInfo of the signal being delivered.
@@ -157,7 +149,7 @@ type Context interface {
// stack is not going to be used).
//
// sigset is the signal mask before entering the signal handler.
- SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error
+ SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error
// SignalRestore restores context after returning from a signal
// handler.
@@ -167,7 +159,7 @@ type Context interface {
// rt is true if SignalRestore is being entered from rt_sigreturn and
// false if SignalRestore is being entered from sigreturn.
// SignalRestore returns the thread's new signal mask.
- SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error)
+ SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error)
// CPUIDEmulate emulates a CPUID instruction according to current register state.
CPUIDEmulate(l log.Logger)
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
deleted file mode 100644
index 67d7edf68..000000000
--- a/pkg/sentry/arch/signal.go
+++ /dev/null
@@ -1,276 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/hostarch"
-)
-
-// SignalAct represents the action that should be taken when a signal is
-// delivered, and is equivalent to struct sigaction.
-//
-// +marshal
-// +stateify savable
-type SignalAct struct {
- Handler uint64
- Flags uint64
- Restorer uint64 // Only used on amd64.
- Mask linux.SignalSet
-}
-
-// SerializeFrom implements NativeSignalAct.SerializeFrom.
-func (s *SignalAct) SerializeFrom(other *SignalAct) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalAct.DeserializeTo.
-func (s *SignalAct) DeserializeTo(other *SignalAct) {
- *other = *s
-}
-
-// SignalStack represents information about a user stack, and is equivalent to
-// stack_t.
-//
-// +marshal
-// +stateify savable
-type SignalStack struct {
- Addr uint64
- Flags uint32
- _ uint32
- Size uint64
-}
-
-// SerializeFrom implements NativeSignalStack.SerializeFrom.
-func (s *SignalStack) SerializeFrom(other *SignalStack) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalStack.DeserializeTo.
-func (s *SignalStack) DeserializeTo(other *SignalStack) {
- *other = *s
-}
-
-// SignalInfo represents information about a signal being delivered, and is
-// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h).
-//
-// +marshal
-// +stateify savable
-type SignalInfo struct {
- Signo int32 // Signal number
- Errno int32 // Errno value
- Code int32 // Signal code
- _ uint32
-
- // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
- // are accessed through methods.
- //
- // For reference, here is the definition of _sifields: (_sigfault._trapno,
- // which does not exist on x86, omitted for clarity)
- //
- // union {
- // int _pad[SI_PAD_SIZE];
- //
- // /* kill() */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // } _kill;
- //
- // /* POSIX.1b timers */
- // struct {
- // __kernel_timer_t _tid; /* timer id */
- // int _overrun; /* overrun count */
- // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
- // sigval_t _sigval; /* same as below */
- // int _sys_private; /* not to be passed to user */
- // } _timer;
- //
- // /* POSIX.1b signals */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // sigval_t _sigval;
- // } _rt;
- //
- // /* SIGCHLD */
- // struct {
- // __kernel_pid_t _pid; /* which child */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // int _status; /* exit code */
- // __ARCH_SI_CLOCK_T _utime;
- // __ARCH_SI_CLOCK_T _stime;
- // } _sigchld;
- //
- // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
- // struct {
- // void *_addr; /* faulting insn/memory ref. */
- // short _addr_lsb; /* LSB of the reported address */
- // } _sigfault;
- //
- // /* SIGPOLL */
- // struct {
- // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
- // int _fd;
- // } _sigpoll;
- //
- // /* SIGSYS */
- // struct {
- // void *_call_addr; /* calling user insn */
- // int _syscall; /* triggering system call number */
- // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
- // } _sigsys;
- // } _sifields;
- //
- // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
- // bytes.
- Fields [128 - 16]byte
-}
-
-// FixSignalCodeForUser fixes up si_code.
-//
-// The si_code we get from Linux may contain the kernel-specific code in the
-// top 16 bits if it's positive (e.g., from ptrace). Linux's
-// copy_siginfo_to_user does
-// err |= __put_user((short)from->si_code, &to->si_code);
-// to mask out those bits and we need to do the same.
-func (s *SignalInfo) FixSignalCodeForUser() {
- if s.Code > 0 {
- s.Code &= 0x0000ffff
- }
-}
-
-// PID returns the si_pid field.
-func (s *SignalInfo) PID() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetPID mutates the si_pid field.
-func (s *SignalInfo) SetPID(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// UID returns the si_uid field.
-func (s *SignalInfo) UID() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetUID mutates the si_uid field.
-func (s *SignalInfo) SetUID(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
-func (s *SignalInfo) Sigval() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[8:16])
-}
-
-// SetSigval mutates the sigval field.
-func (s *SignalInfo) SetSigval(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[8:16], val)
-}
-
-// TimerID returns the si_timerid field.
-func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetTimerID sets the si_timerid field.
-func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Overrun returns the si_overrun field.
-func (s *SignalInfo) Overrun() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetOverrun sets the si_overrun field.
-func (s *SignalInfo) SetOverrun(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Addr returns the si_addr field.
-func (s *SignalInfo) Addr() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetAddr sets the si_addr field.
-func (s *SignalInfo) SetAddr(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Status returns the si_status field.
-func (s *SignalInfo) Status() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetStatus mutates the si_status field.
-func (s *SignalInfo) SetStatus(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// CallAddr returns the si_call_addr field.
-func (s *SignalInfo) CallAddr() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetCallAddr mutates the si_call_addr field.
-func (s *SignalInfo) SetCallAddr(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Syscall returns the si_syscall field.
-func (s *SignalInfo) Syscall() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetSyscall mutates the si_syscall field.
-func (s *SignalInfo) SetSyscall(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// Arch returns the si_arch field.
-func (s *SignalInfo) Arch() uint32 {
- return hostarch.ByteOrder.Uint32(s.Fields[12:16])
-}
-
-// SetArch mutates the si_arch field.
-func (s *SignalInfo) SetArch(val uint32) {
- hostarch.ByteOrder.PutUint32(s.Fields[12:16], val)
-}
-
-// Band returns the si_band field.
-func (s *SignalInfo) Band() int64 {
- return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8]))
-}
-
-// SetBand mutates the si_band field.
-func (s *SignalInfo) SetBand(val int64) {
- // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
- // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
- // `int`. See siginfo.h.
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
-}
-
-// FD returns the si_fd field.
-func (s *SignalInfo) FD() uint32 {
- return hostarch.ByteOrder.Uint32(s.Fields[8:12])
-}
-
-// SetFD mutates the si_fd field.
-func (s *SignalInfo) SetFD(val uint32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], val)
-}
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
deleted file mode 100644
index d3e2324a8..000000000
--- a/pkg/sentry/arch/signal_act.go
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-import "gvisor.dev/gvisor/pkg/marshal"
-
-// Special values for SignalAct.Handler.
-const (
- // SignalActDefault is SIG_DFL and specifies that the default behavior for
- // a signal should be taken.
- SignalActDefault = 0
-
- // SignalActIgnore is SIG_IGN and specifies that a signal should be
- // ignored.
- SignalActIgnore = 1
-)
-
-// Available signal flags.
-const (
- SignalFlagNoCldStop = 0x00000001
- SignalFlagNoCldWait = 0x00000002
- SignalFlagSigInfo = 0x00000004
- SignalFlagRestorer = 0x04000000
- SignalFlagOnStack = 0x08000000
- SignalFlagRestart = 0x10000000
- SignalFlagInterrupt = 0x20000000
- SignalFlagNoDefer = 0x40000000
- SignalFlagResetHandler = 0x80000000
-)
-
-// IsSigInfo returns true iff this handle expects siginfo.
-func (s SignalAct) IsSigInfo() bool {
- return s.Flags&SignalFlagSigInfo != 0
-}
-
-// IsNoDefer returns true iff this SignalAct has the NoDefer flag set.
-func (s SignalAct) IsNoDefer() bool {
- return s.Flags&SignalFlagNoDefer != 0
-}
-
-// IsRestart returns true iff this SignalAct has the Restart flag set.
-func (s SignalAct) IsRestart() bool {
- return s.Flags&SignalFlagRestart != 0
-}
-
-// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set.
-func (s SignalAct) IsResetHandler() bool {
- return s.Flags&SignalFlagResetHandler != 0
-}
-
-// IsOnStack returns true iff this SignalAct has the OnStack flag set.
-func (s SignalAct) IsOnStack() bool {
- return s.Flags&SignalFlagOnStack != 0
-}
-
-// HasRestorer returns true iff this SignalAct has the Restorer flag set.
-func (s SignalAct) HasRestorer() bool {
- return s.Flags&SignalFlagRestorer != 0
-}
-
-// NativeSignalAct is a type that is equivalent to struct sigaction in the
-// guest architecture.
-type NativeSignalAct interface {
- marshal.Marshallable
-
- // SerializeFrom copies the data in the host SignalAct s into this object.
- SerializeFrom(s *SignalAct)
-
- // DeserializeTo copies the data in this object into the host SignalAct s.
- DeserializeTo(s *SignalAct)
-}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 082ed92b1..58e28dbba 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -76,21 +76,11 @@ const (
type UContext64 struct {
Flags uint64
Link uint64
- Stack SignalStack
+ Stack linux.SignalStack
MContext SignalContext64
Sigset linux.SignalSet
}
-// NewSignalAct implements Context.NewSignalAct.
-func (c *context64) NewSignalAct() NativeSignalAct {
- return &SignalAct{}
-}
-
-// NewSignalStack implements Context.NewSignalStack.
-func (c *context64) NewSignalStack() NativeSignalStack {
- return &SignalStack{}
-}
-
// From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the
// size of the magic cookie at the end of the xsave frame.
//
@@ -110,7 +100,7 @@ func (c *context64) fpuFrameSize() (size int, useXsave bool) {
// SignalSetup implements Context.SignalSetup. (Compare to Linux's
// arch/x86/kernel/signal.c:__setup_rt_frame().)
-func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error {
sp := st.Bottom
// "The 128-byte area beyond the location pointed to by %rsp is considered
@@ -187,7 +177,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// Prior to proceeding, figure out if the frame will exhaust the range
// for the signal stack. This is not allowed, and should immediately
// force signal delivery (reverting to the default handler).
- if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) {
return unix.EFAULT
}
@@ -203,7 +193,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
return err
}
ucAddr := st.Bottom
- if act.HasRestorer() {
+ if act.Flags&linux.SA_RESTORER != 0 {
// Push the restorer return address.
// Note that this doesn't need to be popped.
if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil {
@@ -237,15 +227,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// SignalRestore implements Context.SignalRestore. (Compare to Linux's
// arch/x86/kernel/signal.c:sys_rt_sigreturn().)
-func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
- var info SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
// Restore registers.
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index da71fb873..80df90076 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -61,7 +61,7 @@ type FpsimdContext struct {
type UContext64 struct {
Flags uint64
Link uint64
- Stack SignalStack
+ Stack linux.SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
_pad [120]byte // (1024 - 64) / 8 = 120
@@ -71,18 +71,8 @@ type UContext64 struct {
MContext SignalContext64
}
-// NewSignalAct implements Context.NewSignalAct.
-func (c *context64) NewSignalAct() NativeSignalAct {
- return &SignalAct{}
-}
-
-// NewSignalStack implements Context.NewSignalStack.
-func (c *context64) NewSignalStack() NativeSignalStack {
- return &SignalStack{}
-}
-
// SignalSetup implements Context.SignalSetup.
-func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error {
sp := st.Bottom
// Construct the UContext64 now since we need its size.
@@ -114,7 +104,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// Prior to proceeding, figure out if the frame will exhaust the range
// for the signal stack. This is not allowed, and should immediately
// force signal delivery (reverting to the default handler).
- if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) {
return unix.EFAULT
}
@@ -137,7 +127,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Regs[0] = uint64(info.Signo)
c.Regs.Regs[1] = uint64(infoAddr)
c.Regs.Regs[2] = uint64(ucAddr)
- c.Regs.Regs[30] = uint64(act.Restorer)
+ c.Regs.Regs[30] = act.Restorer
// Save the thread's floating point state.
c.sigFPState = append(c.sigFPState, c.fpState)
@@ -147,15 +137,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
}
// SignalRestore implements Context.SignalRestore.
-func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
- var info SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
// Restore registers.
diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go
deleted file mode 100644
index f93ee8b46..000000000
--- a/pkg/sentry/arch/signal_info.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-// Possible values for SignalInfo.Code. These values originate from the Linux
-// kernel's include/uapi/asm-generic/siginfo.h.
-const (
- // SignalInfoUser (properly SI_USER) indicates that a signal was sent from
- // a kill() or raise() syscall.
- SignalInfoUser = 0
-
- // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent
- // by the kernel.
- SignalInfoKernel = 0x80
-
- // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent
- // by an expired timer.
- SignalInfoTimer = -2
-
- // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent
- // from a tkill() or tgkill() syscall.
- SignalInfoTkill = -6
-
- // CLD_* codes are only meaningful for SIGCHLD.
-
- // CLD_EXITED indicates that a task exited.
- CLD_EXITED = 1
-
- // CLD_KILLED indicates that a task was killed by a signal.
- CLD_KILLED = 2
-
- // CLD_DUMPED indicates that a task was killed by a signal and then dumped
- // core.
- CLD_DUMPED = 3
-
- // CLD_TRAPPED indicates that a task was stopped by ptrace.
- CLD_TRAPPED = 4
-
- // CLD_STOPPED indicates that a thread group completed a group stop.
- CLD_STOPPED = 5
-
- // CLD_CONTINUED indicates that a group-stopped thread group was continued.
- CLD_CONTINUED = 6
-
- // SYS_* codes are only meaningful for SIGSYS.
-
- // SYS_SECCOMP indicates that a signal originates from seccomp.
- SYS_SECCOMP = 1
-
- // TRAP_* codes are only meaningful for SIGTRAP.
-
- // TRAP_BRKPT indicates a breakpoint trap.
- TRAP_BRKPT = 1
-)
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
deleted file mode 100644
index c732c7503..000000000
--- a/pkg/sentry/arch/signal_stack.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build 386 amd64 arm64
-
-package arch
-
-import (
- "gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/marshal"
-)
-
-const (
- // SignalStackFlagOnStack is possible set on return from getaltstack,
- // in order to indicate that the thread is currently on the alt stack.
- SignalStackFlagOnStack = 1
-
- // SignalStackFlagDisable is a flag to indicate the stack is disabled.
- SignalStackFlagDisable = 2
-)
-
-// IsEnabled returns true iff this signal stack is marked as enabled.
-func (s SignalStack) IsEnabled() bool {
- return s.Flags&SignalStackFlagDisable == 0
-}
-
-// Top returns the stack's top address.
-func (s SignalStack) Top() hostarch.Addr {
- return hostarch.Addr(s.Addr + s.Size)
-}
-
-// SetOnStack marks this signal stack as in use.
-//
-// Note that there is no corresponding ClearOnStack, and that this should only
-// be called on copies that are serialized to userspace.
-func (s *SignalStack) SetOnStack() {
- s.Flags |= SignalStackFlagOnStack
-}
-
-// Contains checks if the stack pointer is within this stack.
-func (s *SignalStack) Contains(sp hostarch.Addr) bool {
- return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size)
-}
-
-// NativeSignalStack is a type that is equivalent to stack_t in the guest
-// architecture.
-type NativeSignalStack interface {
- marshal.Marshallable
-
- // SerializeFrom copies the data in the host SignalStack s into this
- // object.
- SerializeFrom(s *SignalStack)
-
- // DeserializeTo copies the data in this object into the host SignalStack
- // s.
- DeserializeTo(s *SignalStack)
-}
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 65a794c7c..85e3515af 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -45,7 +45,7 @@ type Stack struct {
}
// scratchBufLen is the default length of Stack.scratchBuf. The
-// largest structs the stack regularly serializes are arch.SignalInfo
+// largest structs the stack regularly serializes are linux.SignalInfo
// and arch.UContext64. We'll set the default size as the larger of
// the two, arch.UContext64.
var scratchBufLen = (*UContext64)(nil).SizeBytes()
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 367849e75..221e98a01 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -99,6 +99,9 @@ type ExecArgs struct {
// PIDNamespace is the pid namespace for the process being executed.
PIDNamespace *kernel.PIDNamespace
+
+ // Limits is the limit set for the process being executed.
+ Limits *limits.LimitSet
}
// String prints the arguments as a string.
@@ -151,6 +154,10 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
if pidns == nil {
pidns = proc.Kernel.RootPIDNamespace()
}
+ limitSet := args.Limits
+ if limitSet == nil {
+ limitSet = limits.NewLimitSet()
+ }
initArgs := kernel.CreateProcessArgs{
Filename: args.Filename,
Argv: args.Argv,
@@ -161,7 +168,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
- Limits: limits.NewLimitSet(),
+ Limits: limitSet,
MaxSymlinkTraversals: linux.MaxSymlinkTraversals,
UTSNamespace: proc.Kernel.RootUTSNamespace(),
IPCNamespace: proc.Kernel.RootIPCNamespace(),
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
index b90f7c1be..4c99944e7 100644
--- a/pkg/sentry/fs/attr.go
+++ b/pkg/sentry/fs/attr.go
@@ -478,6 +478,20 @@ func (f FilePermissions) AnyRead() bool {
return f.User.Read || f.Group.Read || f.Other.Read
}
+// HasSetUIDOrGID returns true if either the setuid or setgid bit is set.
+func (f FilePermissions) HasSetUIDOrGID() bool {
+ return f.SetUID || f.SetGID
+}
+
+// DropSetUIDAndMaybeGID turns off setuid, and turns off setgid if f allows
+// group execution.
+func (f *FilePermissions) DropSetUIDAndMaybeGID() {
+ f.SetUID = false
+ if f.Group.Execute {
+ f.SetGID = false
+ }
+}
+
// FileOwner represents ownership of a file.
//
// +stateify savable
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index e84ba7a5d..c62effd52 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -16,6 +16,7 @@
package dev
import (
+ "fmt"
"math"
"gvisor.dev/gvisor/pkg/context"
@@ -90,6 +91,11 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In
// New returns the root node of a device filesystem.
func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ shm, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */)
+ if err != nil {
+ panic(fmt.Sprintf("tmpfs.NewDir failed: %v", err))
+ }
+
contents := map[string]*fs.Inode{
"fd": newSymlink(ctx, "/proc/self/fd", msrc),
"stdin": newSymlink(ctx, "/proc/self/fd/0", msrc),
@@ -108,7 +114,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"random": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor),
"urandom": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor),
- "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc),
+ "shm": shm,
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index e1e38b498..8ac3738e9 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -155,12 +155,20 @@ func (h *HostMappable) DecRef(fr memmap.FileRange) {
// T2: Appends to file causing it to grow
// T2: Writes to mapped pages and COW happens
// T1: Continues and wronly invalidates the page mapped in step above.
-func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
+func (h *HostMappable) Truncate(ctx context.Context, newSize int64, uattr fs.UnstableAttr) error {
h.truncateMu.Lock()
defer h.truncateMu.Unlock()
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
+
+ // Truncating a file clears privilege bits.
+ if uattr.Perms.HasSetUIDOrGID() {
+ mask.Perms = true
+ attr.Perms = uattr.Perms
+ attr.Perms.DropSetUIDAndMaybeGID()
+ }
+
if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
@@ -193,10 +201,17 @@ func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64)
}
// Write writes to the file backing this mappable.
-func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64, uattr fs.UnstableAttr) (int64, error) {
h.truncateMu.RLock()
+ defer h.truncateMu.RUnlock()
n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset})
- h.truncateMu.RUnlock()
+ if n > 0 && uattr.Perms.HasSetUIDOrGID() {
+ mask := fs.AttrMask{Perms: true}
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, uattr, false); err != nil {
+ return n, err
+ }
+ }
return n, err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 7856b354b..855029b84 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -310,6 +310,12 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
+ if c.attr.Perms.HasSetUIDOrGID() {
+ masked.Perms = true
+ attr.Perms = c.attr.Perms
+ attr.Perms.DropSetUIDAndMaybeGID()
+ c.attr.Perms = attr.Perms
+ }
if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
@@ -685,13 +691,14 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return done, nil
}
-// maybeGrowFile grows the file's size if data has been written past the old
-// size.
+// maybeUpdateAttrs updates the file's attributes after a write. It updates
+// size if data has been written past the old size, and setuid/setgid if any
+// bytes were written.
//
// Preconditions:
// * rw.c.attrMu must be locked.
// * rw.c.dataMu must be locked.
-func (rw *inodeReadWriter) maybeGrowFile() {
+func (rw *inodeReadWriter) maybeUpdateAttrs(nwritten uint64) {
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.offset > rw.c.attr.Size {
@@ -705,6 +712,12 @@ func (rw *inodeReadWriter) maybeGrowFile() {
rw.c.attr.Usage = rw.offset
rw.c.dirtyAttr.Usage = true
}
+
+ // If bytes were written, ensure setuid and setgid are cleared.
+ if nwritten > 0 && rw.c.attr.Perms.HasSetUIDOrGID() {
+ rw.c.dirtyAttr.Perms = true
+ rw.c.attr.Perms.DropSetUIDAndMaybeGID()
+ }
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
@@ -732,7 +745,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
segMR := seg.Range().Intersect(mr)
ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write)
if err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -744,7 +757,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
srcs = srcs.DropFirst64(n)
rw.c.dirty.MarkDirty(segMR)
if err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -765,7 +778,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
srcs = srcs.DropFirst64(n)
// Partial writes are fine. But we must stop writing.
if n != src.NumBytes() || err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -774,7 +787,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
}
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, nil
}
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index c4a069832..94cb05246 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/fd",
"//pkg/hostarch",
"//pkg/log",
+ "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
index 07a564e92..f8b7a60fc 100644
--- a/pkg/sentry/fs/gofer/cache_policy.go
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -139,7 +139,7 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child
// Walk from parent to child again.
//
- // TODO(b/112031682): If we have a directory FD in the parent
+ // NOTE(b/112031682): If we have a directory FD in the parent
// inodeOperations, then we can use fstatat(2) to get the inode
// attributes instead of making this RPC.
qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name})
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 8f5a87120..73d80d9b5 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -91,7 +92,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
if flags.Write {
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
- fsmetric.GoferOpensWX.Increment()
+ metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
log.Warningf("Opened a writable executable: %q", name)
}
}
@@ -236,10 +237,20 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
// and availability of a host-mappable FD.
if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
- } else if f.inodeOperations.fileState.hostMappable != nil {
- n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
} else {
- n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+ uattr, e := f.UnstableAttr(ctx, file)
+ if e != nil {
+ return 0, e
+ }
+ if f.inodeOperations.fileState.hostMappable != nil {
+ n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset, uattr)
+ } else {
+ n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+ if n > 0 && uattr.Perms.HasSetUIDOrGID() {
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ f.inodeOperations.SetPermissions(ctx, file.Dirent.Inode, uattr.Perms)
+ }
+ }
}
if n == 0 {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index b97635ec4..da3178527 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -600,11 +600,25 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
if i.session().cachePolicy.useCachingInodeOps(inode) {
return i.cachingInodeOps.Truncate(ctx, inode, length)
}
+
+ uattr, err := i.fileState.unstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+
if i.session().cachePolicy == cacheRemoteRevalidating {
- return i.fileState.hostMappable.Truncate(ctx, length)
+ return i.fileState.hostMappable.Truncate(ctx, length, uattr)
+ }
+
+ mask := p9.SetAttrMask{Size: true}
+ attr := p9.SetAttr{Size: uint64(length)}
+ if uattr.Perms.HasSetUIDOrGID() {
+ mask.Permissions = true
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ attr.Permissions = p9.FileMode(uattr.Perms.LinuxMode())
}
- return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
+ return i.fileState.file.setAttr(ctx, mask, attr)
}
// GetXattr implements fs.InodeOperations.GetXattr.
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 6b3627813..940838a44 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -130,7 +130,16 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags))
}
+ // If the parent directory has setgid enabled, the new file inherits the parent's GID.
owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ }
+
hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
if err != nil {
// Could not create the file.
@@ -225,7 +234,18 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s
return syserror.ENAMETOOLONG
}
+ // If the parent directory has setgid enabled, the new directory inherits the
+ // parent's GID and has setgid enabled as well.
owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ perm.SetGID = true
+ }
+
if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
return err
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index b998fb75d..085aa6d61 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -77,6 +77,27 @@ func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandl
}, 0
}
+// +stateify savable
+type maxMapCount struct{}
+
+// NeedsUpdate implements seqfile.SeqSource.
+func (*maxMapCount) NeedsUpdate(int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.
+func (*maxMapCount) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte("2147483647\n"),
+ Handle: (*maxMapCount)(nil),
+ },
+ }, 0
+}
+
func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
h := hostname{
SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
@@ -96,6 +117,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
children := map[string]*fs.Inode{
+ "max_map_count": seqfile.NewSeqFileInode(ctx, &maxMapCount{}, msrc),
"mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc),
"overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc),
}
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 1d09afdd7..4893af56b 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -403,7 +403,7 @@ type ipForwarding struct {
// enabled stores the IPv4 forwarding state on save.
// We must save/restore this here, since a netstack instance
// is created on restore.
- enabled *bool
+ enabled bool
}
func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
@@ -461,13 +461,8 @@ func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return 0, io.EOF
}
- if f.ipf.enabled == nil {
- enabled := f.stack.Forwarding(ipv4.ProtocolNumber)
- f.ipf.enabled = &enabled
- }
-
val := "0\n"
- if *f.ipf.enabled {
+ if f.ipf.enabled {
// Technically, this is not quite compatible with Linux. Linux
// stores these as an integer, so if you write "2" into
// ip_forward, you should get 2 back.
@@ -494,11 +489,8 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO
if err != nil {
return n, err
}
- if f.ipf.enabled == nil {
- f.ipf.enabled = new(bool)
- }
- *f.ipf.enabled = v != 0
- return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled)
+ f.ipf.enabled = v != 0
+ return n, f.stack.SetForwarding(ipv4.ProtocolNumber, f.ipf.enabled)
}
// portRangeInode implements fs.InodeOperations. It provides and allows
diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go
index 4cb4741af..51d2be647 100644
--- a/pkg/sentry/fs/proc/sys_net_state.go
+++ b/pkg/sentry/fs/proc/sys_net_state.go
@@ -47,9 +47,7 @@ func (s *tcpSack) afterLoad() {
// afterLoad is invoked by stateify.
func (ipf *ipForwarding) afterLoad() {
- if ipf.enabled != nil {
- if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil {
- panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err))
- }
+ if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil {
+ panic(fmt.Sprintf("ipf.stack.SetForwarding(%d, %t): %s", ipv4.ProtocolNumber, ipf.enabled, err))
}
}
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index bc117ca6a..b48d475ed 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -151,5 +151,5 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
}
// Construct the tmpfs root.
- return NewDir(ctx, nil, owner, perms, msrc), nil
+ return NewDir(ctx, nil, owner, perms, msrc, nil /* parent */)
}
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index f4de8c968..7faa822f0 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -226,6 +226,12 @@ func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size in
now := ktime.NowFromContext(ctx)
f.attr.ModificationTime = now
f.attr.StatusChangeTime = now
+
+ // Truncating clears privilege bits.
+ f.attr.Perms.SetUID = false
+ if f.attr.Perms.Group.Execute {
+ f.attr.Perms.SetGID = false
+ }
}
f.dataMu.Unlock()
@@ -363,7 +369,14 @@ func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence,
now := ktime.NowFromContext(ctx)
f.attr.ModificationTime = now
f.attr.StatusChangeTime = now
- return src.CopyInTo(ctx, &fileReadWriter{f, offset})
+ nwritten, err := src.CopyInTo(ctx, &fileReadWriter{f, offset})
+
+ // Writing clears privilege bits.
+ if nwritten > 0 {
+ f.attr.Perms.DropSetUIDAndMaybeGID()
+ }
+
+ return nwritten, err
}
type fileReadWriter struct {
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 577052888..6aa8ff331 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -87,7 +87,20 @@ type Dir struct {
var _ fs.InodeOperations = (*Dir)(nil)
// NewDir returns a new directory.
-func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource, parent *fs.Inode) (*fs.Inode, error) {
+ // If the parent has setgid enabled, the new directory enables it and changes
+ // its GID.
+ if parent != nil {
+ parentUattr, err := parent.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ perms.SetGID = true
+ }
+ }
+
d := &Dir{
ramfsDir: ramfs.NewDir(ctx, contents, owner, perms),
kernel: kernel.KernelFromContext(ctx),
@@ -101,7 +114,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn
InodeID: tmpfsDevice.NextIno(),
BlockSize: hostarch.PageSize,
Type: fs.Directory,
- })
+ }), nil
}
// afterLoad is invoked by stateify.
@@ -219,11 +232,21 @@ func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) er
func (d *Dir) newCreateOps() *ramfs.CreateOps {
return &ramfs.CreateOps{
NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
- return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource, dir)
},
NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ // If the parent has setgid enabled, change the GID of the new file.
+ owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ }
+
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
- Owner: fs.FileOwnerFromContext(ctx),
+ Owner: owner,
Perms: perms,
// Always start unlinked.
Links: 0,
diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go
index 12b786224..7f8fa8038 100644
--- a/pkg/sentry/fs/user/user_test.go
+++ b/pkg/sentry/fs/user/user_test.go
@@ -104,7 +104,10 @@ func TestGetExecUserHome(t *testing.T) {
t.Run(name, func(t *testing.T) {
ctx := contexttest.Context(t)
msrc := fs.NewPseudoMountSource(ctx)
- rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc)
+ rootInode, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */)
+ if err != nil {
+ t.Fatalf("tmpfs.NewDir failed: %v", err)
+ }
mns, err := fs.NewMountNamespace(ctx, rootInode)
if err != nil {
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
index 0f54888d8..fe9871bdd 100644
--- a/pkg/sentry/fsimpl/cgroupfs/base.go
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -68,11 +67,6 @@ func (c *controllerCommon) Enabled() bool {
return true
}
-// Filesystem implements kernel.CgroupController.Filesystem.
-func (c *controllerCommon) Filesystem() *vfs.Filesystem {
- return c.fs.VFSFilesystem()
-}
-
// RootCgroup implements kernel.CgroupController.RootCgroup.
func (c *controllerCommon) RootCgroup() kernel.Cgroup {
return c.fs.rootCgroup()
@@ -139,6 +133,17 @@ func (c *cgroupInode) Controllers() []kernel.CgroupController {
return c.fs.kcontrollers
}
+// tasks returns a snapshot of the tasks inside the cgroup.
+func (c *cgroupInode) tasks() []*kernel.Task {
+ c.fs.tasksMu.RLock()
+ defer c.fs.tasksMu.RUnlock()
+ ts := make([]*kernel.Task, 0, len(c.ts))
+ for t := range c.ts {
+ ts = append(ts, t)
+ }
+ return ts
+}
+
// Enter implements kernel.CgroupImpl.Enter.
func (c *cgroupInode) Enter(t *kernel.Task) {
c.fs.tasksMu.Lock()
@@ -169,10 +174,7 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error
pgids := make(map[kernel.ThreadID]struct{})
- d.fs.tasksMu.RLock()
- defer d.fs.tasksMu.RUnlock()
-
- for task := range d.ts {
+ for _, task := range d.tasks() {
// Map dedups pgid, since iterating over all tasks produces multiple
// entries for the group leaders.
if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
@@ -211,10 +213,7 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var pids []kernel.ThreadID
- d.fs.tasksMu.RLock()
- defer d.fs.tasksMu.RUnlock()
-
- for task := range d.ts {
+ for _, task := range d.tasks() {
if pid := currPidns.IDOfTask(task); pid != 0 {
pids = append(pids, pid)
}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index bd3e69757..05d7eb4ce 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -49,8 +49,9 @@
//
// kernel.CgroupRegistry.mu
// cgroupfs.filesystem.mu
-// Task.mu
-// cgroupfs.filesystem.tasksMu.
+// kernel.TaskSet.mu
+// kernel.Task.mu
+// cgroupfs.filesystem.tasksMu.
package cgroupfs
import (
@@ -109,7 +110,7 @@ type InternalData struct {
DefaultControlValues map[string]int64
}
-// filesystem implements vfs.FilesystemImpl.
+// filesystem implements vfs.FilesystemImpl and kernel.cgroupFS.
//
// +stateify savable
type filesystem struct {
@@ -139,6 +140,11 @@ type filesystem struct {
tasksMu sync.RWMutex `state:"nosave"`
}
+// InitializeHierarchyID implements kernel.cgroupFS.InitializeHierarchyID.
+func (fs *filesystem) InitializeHierarchyID(hid uint32) {
+ fs.hierarchyID = hid
+}
+
// Name implements vfs.FilesystemType.Name.
func (FilesystemType) Name() string {
return Name
@@ -284,14 +290,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Register controllers. The registry may be modified concurrently, so if we
// get an error, we raced with someone else who registered the same
// controllers first.
- hid, err := r.Register(fs.kcontrollers)
- if err != nil {
+ if err := r.Register(fs.kcontrollers, fs); err != nil {
ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
rootD.DecRef(ctx)
fs.VFSFilesystem().DecRef(ctx)
return nil, nil, syserror.EBUSY
}
- fs.hierarchyID = hid
// Move all existing tasks to the root of the new hierarchy.
k.PopulateNewCgroupHierarchy(fs.rootCgroup())
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 52879f871..368272f12 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -54,6 +54,7 @@ go_library(
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/log",
+ "//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 97ce80853..eb09d54c3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -961,7 +961,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
}
return &fd.vfsfd, nil
case linux.S_IFLNK:
- // Can't open symlinks without O_PATH (which is unimplemented).
+ // Can't open symlinks without O_PATH, which is handled at the VFS layer.
return nil, syserror.ELOOP
case linux.S_IFSOCK:
if d.isSynthetic() {
@@ -1194,11 +1194,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- // Requires 9P support.
- return syserror.EINVAL
- }
-
+ // Resolve newParent first to verify that it's on this Mount.
var ds *[]*dentry
fs.renameMu.Lock()
defer fs.renameMuUnlockAndCheckCaching(ctx, &ds)
@@ -1206,8 +1202,21 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return syserror.EINVAL
+ }
+ if fs.opts.interop == InteropModeShared && opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ // Requires 9P support to synchronize with other remote filesystem
+ // users.
+ return syserror.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -1280,6 +1289,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
var replacedVFSD *vfs.Dentry
if replaced != nil {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 21692d2ac..cf69e1b7a 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1282,9 +1282,12 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
}
func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
- // We only support xattrs prefixed with "user." (see b/148380782). Currently,
- // there is no need to expose any other xattrs through a gofer.
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ // Deny access to the "security" and "system" namespaces since applications
+ // may expect these to affect kernel behavior in unimplemented ways
+ // (b/148380782). Allow all other extended attributes to be passed through
+ // to the remote filesystem. This is inconsistent with Linux's 9p client,
+ // but consistent with other filesystems (e.g. FUSE).
+ if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) {
return syserror.EOPNOTSUPP
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
@@ -1684,7 +1687,7 @@ func (d *dentry) setDeleted() {
}
func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() || !d.userXattrSupported() {
+ if d.file.isNil() {
return nil, nil
}
xattrMap, err := d.file.listXattr(ctx, size)
@@ -1693,10 +1696,7 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui
}
xattrs := make([]string, 0, len(xattrMap))
for x := range xattrMap {
- // We only support xattrs in the user.* namespace.
- if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
- xattrs = append(xattrs, x)
- }
+ xattrs = append(xattrs, x)
}
return xattrs, nil
}
@@ -1731,13 +1731,6 @@ func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name
return d.file.removeXattr(ctx, name)
}
-// Extended attributes in the user.* namespace are only supported for regular
-// files and directories.
-func (d *dentry) userXattrSupported() bool {
- filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType()
- return filetype == linux.ModeRegular || filetype == linux.ModeDirectory
-}
-
// Preconditions:
// * !d.isSynthetic().
// * d.isRegularFile() || d.isDir().
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f0e7bbaf7..eed05e369 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -59,7 +60,7 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD,
return nil, err
}
if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
- fsmetric.GoferOpensWX.Increment()
+ metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
}
if atomic.LoadInt32(&d.mmapFD) >= 0 {
fsmetric.GoferOpensHost.Increment()
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index ac3b5b621..c12444b7e 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -100,7 +101,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci
d.fs.specialFileFDs[fd] = struct{}{}
d.fs.syncMu.Unlock()
if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
- fsmetric.GoferOpensWX.Increment()
+ metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
}
if h.fd >= 0 {
fsmetric.GoferOpensHost.Increment()
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index f50b0fb08..8fac53c60 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -635,12 +635,6 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- // Only RENAME_NOREPLACE is supported.
- if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
- return syserror.EINVAL
- }
- noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
-
fs.mu.Lock()
defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
@@ -651,6 +645,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ // Only RENAME_NOREPLACE is supported.
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return syserror.EINVAL
+ }
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+
mnt := rp.Mount()
if mnt != oldParentVD.Mount() {
return syserror.EXDEV
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 46c500427..6b6fa0bd5 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1017,10 +1017,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- return syserror.EINVAL
- }
-
+ // Resolve newParent first to verify that it's on this Mount.
var ds *[]*dentry
fs.renameMu.Lock()
defer fs.renameMuUnlockAndCheckDrop(ctx, &ds)
@@ -1028,8 +1025,16 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return syserror.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -1093,6 +1098,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return err
}
if replaced != nil {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 43bfd69a3..82491a0f8 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -207,9 +207,10 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
return err
}
- // Changing owners may clear one or both of the setuid and setgid bits,
- // so we may have to update opts before setting d.mode.
- if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ // Changing owners or truncating may clear one or both of the setuid and
+ // setgid bits, so we may have to update opts before setting d.mode.
+ inotifyMask := opts.Stat.Mask
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{
Mask: linux.STATX_MODE,
})
@@ -218,10 +219,14 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
}
opts.Stat.Mode = stat.Mode
opts.Stat.Mask |= linux.STATX_MODE
+ // Don't generate inotify IN_ATTRIB for size-only changes (truncations).
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ inotifyMask |= linux.STATX_MODE
+ }
}
d.updateAfterSetStatLocked(&opts)
- if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ if ev := vfs.InotifyEventFromStatMask(inotifyMask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 9b14dd6b9..2bc98a94f 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -55,6 +55,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
}),
}),
"vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "max_map_count": fs.newInode(ctx, root, 0444, newStaticFile("2147483647\n")),
"mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}),
"overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")),
}),
@@ -365,27 +366,22 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error {
}
// ipForwarding implements vfs.WritableDynamicBytesSource for
-// /proc/sys/net/ipv4/ip_forwarding.
+// /proc/sys/net/ipv4/ip_forward.
//
// +stateify savable
type ipForwarding struct {
kernfs.DynamicBytesFile
stack inet.Stack `state:"wait"`
- enabled *bool
+ enabled bool
}
var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if ipf.enabled == nil {
- enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber)
- ipf.enabled = &enabled
- }
-
val := "0\n"
- if *ipf.enabled {
+ if ipf.enabled {
// Technically, this is not quite compatible with Linux. Linux stores these
// as an integer, so if you write "2" into ip_forward, you should get 2 back.
// Tough luck.
@@ -414,11 +410,8 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs
if err != nil {
return 0, err
}
- if ipf.enabled == nil {
- ipf.enabled = new(bool)
- }
- *ipf.enabled = v != 0
- if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil {
+ ipf.enabled = v != 0
+ if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil {
return 0, err
}
return n, nil
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index 6cee22823..19b012f7d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
@@ -132,7 +132,7 @@ func TestConfigureIPForwarding(t *testing.T) {
t.Run(c.comment, func(t *testing.T) {
s.IPForwarding = c.initial
- file := &ipForwarding{stack: s, enabled: &c.initial}
+ file := &ipForwarding{stack: s, enabled: c.initial}
// Write the values.
src := usermem.BytesIOSequence([]byte(c.str))
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index c766164c7..b3f9d1010 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -17,7 +17,6 @@ go_library(
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/memutil",
- "//pkg/metric",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 33e52ce64..97aa20cd1 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/memutil"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -63,8 +62,6 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating platform: %v", err)
}
- metric.CreateSentryMetrics()
-
kernel.VFS2Enabled = true
k := &kernel.Kernel{
Platform: plat,
@@ -88,6 +85,7 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating timekeeper: %v", err)
}
tk.SetClocks(time.NewCalibratedClocks())
+ k.SetTimekeeper(tk)
creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
@@ -96,7 +94,6 @@ func Boot() (*kernel.Kernel, error) {
if err = k.Init(kernel.InitKernelArgs{
ApplicationCores: uint(runtime.GOMAXPROCS(-1)),
FeatureSet: cpuid.HostFeatureSet(),
- Timekeeper: tk,
RootUserNamespace: creds.UserNamespace,
Vdso: vdso,
RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
@@ -181,7 +178,7 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) {
memfile := os.NewFile(uintptr(memfd), memfileName)
mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
if err != nil {
- memfile.Close()
+ _ = memfile.Close()
return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
}
return mf, nil
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 5fdca1d46..f0f4297ef 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -465,7 +465,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
}
return &fd.vfsfd, nil
case *symlink:
- // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH.
+ // Can't open symlinks without O_PATH, which is handled at the VFS layer.
return nil, syserror.ELOOP
case *namedPipe:
return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks)
@@ -496,20 +496,24 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- // TODO(b/145974740): Support renameat2 flags.
- return syserror.EINVAL
- }
-
- // Resolve newParent first to verify that it's on this Mount.
+ // Resolve newParentDir first to verify that it's on this Mount.
fs.mu.Lock()
defer fs.mu.Unlock()
newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ // TODO(b/145974740): Support other renameat2 flags.
+ return syserror.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -556,6 +560,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
replaced, ok := newParentDir.childMap[newName]
if ok {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedDir, ok := replaced.inode.impl.(*directory)
if ok {
if !renamed.inode.isDir() {
@@ -815,7 +822,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.inode.listXattr(size)
+ return d.inode.listXattr(rp.Credentials(), size)
}
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 9ae25ce9e..6b4367c42 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -717,44 +717,63 @@ func (i *inode) touchCMtimeLocked() {
atomic.StoreInt64(&i.ctime, now)
}
-func (i *inode) listXattr(size uint64) ([]string, error) {
- return i.xattrs.ListXattr(size)
+func checkXattrName(name string) error {
+ // Linux's tmpfs supports "security" and "trusted" xattr namespaces, and
+ // (depending on build configuration) POSIX ACL xattr namespaces
+ // ("system.posix_acl_access" and "system.posix_acl_default"). We don't
+ // support POSIX ACLs or the "security" namespace (b/148380782).
+ if strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
+ return nil
+ }
+ // We support the "user" namespace because we have tests that depend on
+ // this feature.
+ if strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return nil
+ }
+ return syserror.EOPNOTSUPP
+}
+
+func (i *inode) listXattr(creds *auth.Credentials, size uint64) ([]string, error) {
+ return i.xattrs.ListXattr(creds, size)
}
func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
- if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
+ if err := checkXattrName(opts.Name); err != nil {
return "", err
}
- return i.xattrs.GetXattr(opts)
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&i.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&i.gid))
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayRead, mode, kuid, kgid); err != nil {
+ return "", err
+ }
+ return i.xattrs.GetXattr(creds, mode, kuid, opts)
}
func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
- if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
+ if err := checkXattrName(opts.Name); err != nil {
return err
}
- return i.xattrs.SetXattr(opts)
-}
-
-func (i *inode) removeXattr(creds *auth.Credentials, name string) error {
- if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&i.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&i.gid))
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil {
return err
}
- return i.xattrs.RemoveXattr(name)
+ return i.xattrs.SetXattr(creds, mode, kuid, opts)
}
-func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
- // We currently only support extended attributes in the user.* and
- // trusted.* namespaces. See b/148380782.
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
- return syserror.EOPNOTSUPP
+func (i *inode) removeXattr(creds *auth.Credentials, name string) error {
+ if err := checkXattrName(name); err != nil {
+ return err
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
kuid := auth.KUID(atomic.LoadUint32(&i.uid))
kgid := auth.KGID(atomic.LoadUint32(&i.gid))
- if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil {
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil {
return err
}
- return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name)
+ return i.xattrs.RemoveXattr(creds, mode, kuid, name)
}
// fileDescription is embedded by tmpfs implementations of
@@ -807,7 +826,7 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.inode().listXattr(size)
+ return fd.inode().listXattr(auth.CredentialsFromContext(ctx), size)
}
// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 31d34ef60..969003613 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -868,6 +868,10 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
fd.mu.Lock()
defer fd.mu.Unlock()
+ if _, err := fd.lowerFD.Seek(ctx, fd.off, linux.SEEK_SET); err != nil {
+ return err
+ }
+
var ds []vfs.Dirent
err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
// Do not include the Merkle tree files.
@@ -890,8 +894,8 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
return err
}
- // The result should contain all children plus "." and "..".
- if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
+ // The result should contain the remaining entries: all children plus "." and "..", minus the fd.off entries already consumed.
+ if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2-int(fd.off) {
return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
}
@@ -1299,6 +1303,11 @@ func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapO
return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
}
+// SupportsLocks implements vfs.FileDescriptionImpl.SupportsLocks.
+func (fd *fileDescription) SupportsLocks() bool {
+ return fd.lowerFD.SupportsLocks()
+}
+
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return fd.lowerFD.LockBSD(ctx, ownerPID, t, block)
diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go
index 7e535b527..17d0d5025 100644
--- a/pkg/sentry/fsmetric/fsmetric.go
+++ b/pkg/sentry/fsmetric/fsmetric.go
@@ -42,7 +42,6 @@ var (
// Metrics that only apply to fs/gofer and fsimpl/gofer.
var (
- GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.")
GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.")
GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.")
GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 6b71bd3a9..80dda1559 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -88,9 +88,6 @@ type Stack interface {
// for restoring a stack after a save.
RestoreCleanupEndpoints([]stack.TransportEndpoint)
- // Forwarding returns if packet forwarding between NICs is enabled.
- Forwarding(protocol tcpip.NetworkProtocolNumber) bool
-
// SetForwarding enables or disables packet forwarding between NICs.
SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 03e2608c2..218d9dafc 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -154,11 +154,6 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
-// Forwarding implements inet.Stack.Forwarding.
-func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- return s.IPForwarding
-}
-
// SetForwarding implements inet.Stack.SetForwarding.
func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
s.IPForwarding = enable
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a1ec6daab..188c0ebff 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -32,7 +32,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//pkg/sync:generic_seqatomic",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 869e49ebc..12180351d 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//pkg/sync:generic_atomicptr",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
types = {
"Value": "Credentials",
},
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index 6862f2ef5..3325fedcb 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -125,7 +125,7 @@ func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *T
creds.EffectiveCaps = capabilities.EffectiveCaps
creds.BoundingCaps = capabilities.BoundingCaps
creds.InheritableCaps = capabilities.InheritableCaps
- // TODO(nlacasse): Support ambient capabilities.
+ // TODO(gvisor.dev/issue/3166): Support ambient capabilities.
} else {
// If no capabilities are specified, grant capabilities consistent with
// setresuid + setresgid from NewRootCredentials to the given uid and
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
index 1f1c63f37..c93ef6ac1 100644
--- a/pkg/sentry/kernel/cgroup.go
+++ b/pkg/sentry/kernel/cgroup.go
@@ -48,10 +48,6 @@ type CgroupController interface {
// attached to. Returned value is valid for the lifetime of the controller.
HierarchyID() uint32
- // Filesystem returns the filesystem this controller is attached to.
- // Returned value is valid for the lifetime of the controller.
- Filesystem() *vfs.Filesystem
-
// RootCgroup returns the root cgroup for this controller. Returned value is
// valid for the lifetime of the controller.
RootCgroup() Cgroup
@@ -124,6 +120,19 @@ func (h *hierarchy) match(ctypes []CgroupControllerType) bool {
return true
}
+// cgroupFS is the public interface to cgroupfs. This lets the kernel package
+// refer to cgroupfs.filesystem methods without directly depending on the
+// cgroupfs package, which would lead to a circular dependency.
+type cgroupFS interface {
+ // VFSFilesystem returns the vfs.Filesystem for the cgroupfs.
+ VFSFilesystem() *vfs.Filesystem
+
+ // InitializeHierarchyID sets the hierarchy ID for this filesystem during
+ // filesystem creation. May only be called before the filesystem is visible
+ // to the vfs layer.
+ InitializeHierarchyID(hid uint32)
+}
+
// CgroupRegistry tracks the active set of cgroup controllers on the system.
//
// +stateify savable
@@ -172,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files
for _, h := range r.hierarchies {
if h.match(ctypes) {
- h.fs.IncRef()
+ if !h.fs.TryIncRef() {
+ // Racing with filesystem destruction, namely h.fs.Release.
+ // Since we hold r.mu, we know the hierarchy hasn't been
+ // unregistered yet, but its associated filesystem is tearing
+ // down.
+ //
+ // If we simply indicate the hierarchy wasn't found without
+ // cleaning up the registry, the caller can race with the
+ // unregister and find itself temporarily unable to create a new
+ // hierarchy with a subset of the relevant controllers.
+ //
+ // To keep the result of FindHierarchy consistent with the
+ // uniqueness of controllers enforced by Register, drop the
+ // dying hierarchy now. The eventual unregister by the FS
+ // teardown will become a no-op.
+ return nil
+ }
return h.fs
}
}
@@ -182,31 +207,35 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files
// Register registers the provided set of controllers with the registry as a new
// hierarchy. If any controller is already registered, the function returns an
-// error without modifying the registry. The hierarchy can be later referenced
-// by the returned id.
-func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
+// error without modifying the registry. Register sets the hierarchy ID for the
+// filesystem on success.
+func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error {
r.mu.Lock()
defer r.mu.Unlock()
if len(cs) == 0 {
- return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers")
+ return fmt.Errorf("can't register hierarchy with no controllers")
}
for _, c := range cs {
if _, ok := r.controllers[c.Type()]; ok {
- return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy")
+ return fmt.Errorf("controllers may only be mounted on a single hierarchy")
}
}
hid, err := r.nextHierarchyID()
if err != nil {
- return hid, err
+ return err
}
+ // Must not fail below here, once we publish the hierarchy ID.
+
+ fs.InitializeHierarchyID(hid)
+
h := hierarchy{
id: hid,
controllers: make(map[CgroupControllerType]CgroupController),
- fs: cs[0].Filesystem(),
+ fs: fs.VFSFilesystem(),
}
for _, c := range cs {
n := c.Type()
@@ -214,15 +243,20 @@ func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
h.controllers[n] = c
}
r.hierarchies[hid] = h
- return hid, nil
+ return nil
}
-// Unregister removes a previously registered hierarchy from the registry. If
-// the controller was not previously registered, Unregister is a no-op.
+// Unregister removes a previously registered hierarchy from the registry. If no
+// such hierarchy is registered, Unregister is a no-op.
func (r *CgroupRegistry) Unregister(hid uint32) {
r.mu.Lock()
- defer r.mu.Unlock()
+ r.unregisterLocked(hid)
+ r.mu.Unlock()
+}
+// Precondition: Caller must hold r.mu.
+// +checklocks:r.mu
+func (r *CgroupRegistry) unregisterLocked(hid uint32) {
if h, ok := r.hierarchies[hid]; ok {
for name, _ := range h.controllers {
delete(r.controllers, name)
@@ -253,6 +287,11 @@ func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[C
for name, ctl := range r.controllers {
if _, ok := ctlSet[name]; !ok {
cg := ctl.RootCgroup()
+ // Multiple controllers may share the same hierarchy, so may have
+ // the same root cgroup. Grab a single ref per hierarchy root.
+ if _, ok := cgset[cg]; ok {
+ continue
+ }
cg.IncRef() // Ref transferred to caller.
cgset[cg] = struct{}{}
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index f855f038b..6224a0cbd 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index dbbbaeeb0..5d584dc45 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,7 +17,6 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -125,9 +124,9 @@ func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
if !permCheck {
return
}
- signalInfo := &arch.SignalInfo{
+ signalInfo := &linux.SignalInfo{
Signo: int32(linux.SIGIO),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
if a.signal != 0 {
signalInfo.Signo = int32(a.signal)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 10885688c..62777faa8 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -154,9 +154,11 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) {
// dropVFS2 drops the table reference.
func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) {
// Release any POSIX lock possibly held by the FDTable.
- err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF})
- if err != nil && err != syserror.ENOLCK {
- panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
+ if file.SupportsLocks() {
+ err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF})
+ if err != nil && err != syserror.ENOLCK {
+ panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
+ }
}
// Drop the table's reference.
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index a75686cf3..6c31e082c 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//pkg/sync:generic_atomicptr",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
types = {
"Value": "bucket",
},
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index e6e9da898..d537e608a 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -306,9 +306,6 @@ type InitKernelArgs struct {
// FeatureSet is the emulated CPU feature set.
FeatureSet *cpuid.FeatureSet
- // Timekeeper manages time for all tasks in the system.
- Timekeeper *Timekeeper
-
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
@@ -348,29 +345,34 @@ type InitKernelArgs struct {
PIDNamespace *PIDNamespace
}
+// SetTimekeeper sets Kernel.timekeeper. SetTimekeeper must be called before
+// Init.
+func (k *Kernel) SetTimekeeper(tk *Timekeeper) {
+ k.timekeeper = tk
+}
+
// Init initialize the Kernel with no tasks.
//
// Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile
-// before calling Init.
+// and Kernel.SetTimekeeper before calling Init.
func (k *Kernel) Init(args InitKernelArgs) error {
if args.FeatureSet == nil {
- return fmt.Errorf("FeatureSet is nil")
+ return fmt.Errorf("args.FeatureSet is nil")
}
- if args.Timekeeper == nil {
- return fmt.Errorf("Timekeeper is nil")
+ if k.timekeeper == nil {
+ return fmt.Errorf("timekeeper is nil")
}
- if args.Timekeeper.clocks == nil {
+ if k.timekeeper.clocks == nil {
return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()")
}
if args.RootUserNamespace == nil {
- return fmt.Errorf("RootUserNamespace is nil")
+ return fmt.Errorf("args.RootUserNamespace is nil")
}
if args.ApplicationCores == 0 {
- return fmt.Errorf("ApplicationCores is 0")
+ return fmt.Errorf("args.ApplicationCores is 0")
}
k.featureSet = args.FeatureSet
- k.timekeeper = args.Timekeeper
k.tasks = newTaskSet(args.PIDNamespace)
k.rootUserNamespace = args.RootUserNamespace
k.rootUTSNamespace = args.RootUTSNamespace
@@ -395,8 +397,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
}
k.extraAuxv = args.ExtraAuxv
k.vdso = args.Vdso
- k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime}
- k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
+ k.realtimeClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Realtime}
+ k.monotonicClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
k.ptraceExceptions = make(map[*Task]*Task)
@@ -654,12 +656,12 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
defer k.tasks.mu.RUnlock()
for t := range k.tasks.Root.tids {
// We can skip locking Task.mu here since the kernel is paused.
- if mm := t.image.MemoryManager; mm != nil {
- if _, ok := invalidated[mm]; !ok {
- if err := mm.InvalidateUnsavable(ctx); err != nil {
+ if memMgr := t.image.MemoryManager; memMgr != nil {
+ if _, ok := invalidated[memMgr]; !ok {
+ if err := memMgr.InvalidateUnsavable(ctx); err != nil {
return err
}
- invalidated[mm] = struct{}{}
+ invalidated[memMgr] = struct{}{}
}
}
// I really wish we just had a sync.Map of all MMs...
@@ -1339,7 +1341,7 @@ func (k *Kernel) Unpause() {
// context is used only for debugging to describe how the signal was received.
//
// Preconditions: Kernel must have an init process.
-func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
+func (k *Kernel) SendExternalSignal(info *linux.SignalInfo, context string) {
k.extMu.Lock()
defer k.extMu.Unlock()
k.sendExternalSignal(info, context)
@@ -1347,7 +1349,7 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
// SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup.
// This function doesn't skip signals like SendExternalSignal does.
-func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *linux.SignalInfo) error {
k.extMu.Lock()
defer k.extMu.Unlock()
return tg.SendSignal(info)
@@ -1355,7 +1357,7 @@ func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.Signa
// SendContainerSignal sends the given signal to all processes inside the
// namespace that match the given container ID.
-func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
+func (k *Kernel) SendContainerSignal(cid string, info *linux.SignalInfo) error {
k.extMu.Lock()
defer k.extMu.Unlock()
k.tasks.mu.RLock()
@@ -1553,22 +1555,23 @@ func (k *Kernel) SetSaveError(err error) {
var _ tcpip.Clock = (*Kernel)(nil)
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (k *Kernel) NowNanoseconds() int64 {
- now, err := k.timekeeper.GetTime(sentrytime.Realtime)
+// Now implements tcpip.Clock.Now.
+func (k *Kernel) Now() time.Time {
+ nsec, err := k.timekeeper.GetTime(sentrytime.Realtime)
if err != nil {
- panic("Kernel.NowNanoseconds: " + err.Error())
+ panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error())
}
- return now
+ return time.Unix(0, nsec)
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (k *Kernel) NowMonotonic() int64 {
- now, err := k.timekeeper.GetTime(sentrytime.Monotonic)
+func (k *Kernel) NowMonotonic() tcpip.MonotonicTime {
+ nsec, err := k.timekeeper.GetTime(sentrytime.Monotonic)
if err != nil {
- panic("Kernel.NowMonotonic: " + err.Error())
+ panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error())
}
- return now
+ var mt tcpip.MonotonicTime
+ return mt.Add(time.Duration(nsec) * time.Nanosecond)
}
// AfterFunc implements tcpip.Clock.AfterFunc.
@@ -1783,7 +1786,7 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
})
t := TaskFromContext(ctx)
- k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
+ _, _ = k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
Tid: int32(t.ThreadID()),
Registers: t.Arch().StateData().Proto(),
})
@@ -1858,7 +1861,9 @@ func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
return
}
t.mu.Lock()
- t.enterCgroupLocked(root)
+ // A task may already be in the cgroup if it was created after the
+ // cgroup hierarchy was registered.
+ t.enterCgroupIfNotYetLocked(root)
t.mu.Unlock()
})
k.tasks.mu.RUnlock()
@@ -1874,7 +1879,7 @@ func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
return
}
t.mu.Lock()
- for cg, _ := range t.cgroups {
+ for cg := range t.cgroups {
if cg.HierarchyID() == hid {
t.leaveCgroupLocked(cg)
}
diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go
index 77a35b788..af455c434 100644
--- a/pkg/sentry/kernel/pending_signals.go
+++ b/pkg/sentry/kernel/pending_signals.go
@@ -17,7 +17,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
- "gvisor.dev/gvisor/pkg/sentry/arch"
)
const (
@@ -65,7 +64,7 @@ type pendingSignalQueue struct {
type pendingSignal struct {
// pendingSignalEntry links into a pendingSignalList.
pendingSignalEntry
- *arch.SignalInfo
+ *linux.SignalInfo
// If timer is not nil, it is the IntervalTimer which sent this signal.
timer *IntervalTimer
@@ -75,7 +74,7 @@ type pendingSignal struct {
// on failure (if the given signal's queue is full).
//
// Preconditions: info represents a valid signal.
-func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool {
+func (p *pendingSignals) enqueue(info *linux.SignalInfo, timer *IntervalTimer) bool {
sig := linux.Signal(info.Signo)
q := &p.signals[sig.Index()]
if sig.IsStandard() {
@@ -93,7 +92,7 @@ func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bo
// dequeue dequeues and returns any pending signal not masked by mask. If no
// unmasked signals are pending, dequeue returns nil.
-func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
+func (p *pendingSignals) dequeue(mask linux.SignalSet) *linux.SignalInfo {
// "Real-time signals are delivered in a guaranteed order. Multiple
// real-time signals of the same type are delivered in the order they were
// sent. If different real-time signals are sent to a process, they are
@@ -111,7 +110,7 @@ func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1))
}
-func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo {
+func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *linux.SignalInfo {
q := &p.signals[sig.Index()]
ps := q.pendingSignalList.Front()
if ps == nil {
diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go
index ca8b4e164..e77f1a254 100644
--- a/pkg/sentry/kernel/pending_signals_state.go
+++ b/pkg/sentry/kernel/pending_signals_state.go
@@ -14,13 +14,11 @@
package kernel
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
-)
+import "gvisor.dev/gvisor/pkg/abi/linux"
// +stateify savable
type savedPendingSignal struct {
- si *arch.SignalInfo
+ si *linux.SignalInfo
timer *IntervalTimer
}
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 2d89b9ccd..3fa5d1d2f 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
if n > 0 {
p.Notify(waiter.ReadableEvents)
}
+ if err == unix.EPIPE {
+ // If we are returning EPIPE send SIGPIPE to the task.
+ if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil {
+ sendSig(linux.SIGPIPE)
+ }
+ }
return n, err
}
@@ -129,7 +135,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
v = math.MaxInt32 // Silently truncate.
}
// Copy result to userspace.
- iocc := primitive.IOCopyContext{
+ iocc := usermem.IOCopyContext{
IO: io,
Ctx: ctx,
Opts: usermem.IOOpts{
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index 2e861a5a8..d801a3d83 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -18,7 +18,6 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -97,7 +96,7 @@ func (it *IntervalTimer) ResumeTimer() {
}
// Preconditions: it.target's signal mutex must be locked.
-func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) {
+func (it *IntervalTimer) updateDequeuedSignalLocked(si *linux.SignalInfo) {
it.sigpending = false
if it.sigorphan {
return
@@ -138,9 +137,9 @@ func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Settin
it.sigpending = true
it.sigorphan = false
it.overrunCur += exp - 1
- si := &arch.SignalInfo{
+ si := &linux.SignalInfo{
Signo: int32(it.signo),
- Code: arch.SignalInfoTimer,
+ Code: linux.SI_TIMER,
}
si.SetTimerID(it.id)
si.SetSigval(it.sigval)
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 57c7659e7..a6287fd6a 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -394,7 +393,7 @@ func (t *Task) ptraceTrapLocked(code int32) {
t.trapStopPending = false
t.tg.signalHandlers.mu.Unlock()
t.ptraceCode = code
- t.ptraceSiginfo = &arch.SignalInfo{
+ t.ptraceSiginfo = &linux.SignalInfo{
Signo: int32(linux.SIGTRAP),
Code: code,
}
@@ -402,7 +401,7 @@ func (t *Task) ptraceTrapLocked(code int32) {
t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
- tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
+ tracer.signalStop(t, linux.CLD_TRAPPED, int32(linux.SIGTRAP))
tracer.tg.eventQueue.Notify(EventTraceeStop)
}
}
@@ -542,9 +541,9 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
// "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." -
// ptrace(2)
if !seize {
- target.sendSignalLocked(&arch.SignalInfo{
+ target.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGSTOP),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, false /* group */)
}
// Undocumented Linux feature: If the tracee is already group-stopped (and
@@ -586,7 +585,7 @@ func (t *Task) exitPtrace() {
for target := range t.ptraceTracees {
if target.ptraceOpts.ExitKill {
target.tg.signalHandlers.mu.Lock()
- target.sendSignalLocked(&arch.SignalInfo{
+ target.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
}, false /* group */)
target.tg.signalHandlers.mu.Unlock()
@@ -652,7 +651,7 @@ func (t *Task) forgetTracerLocked() {
// Preconditions:
// * The signal mutex must be locked.
// * The caller must be running on the task goroutine.
-func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
+func (t *Task) ptraceSignalLocked(info *linux.SignalInfo) bool {
if linux.Signal(info.Signo) == linux.SIGKILL {
return false
}
@@ -678,7 +677,7 @@ func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
t.ptraceSiginfo = info
t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo)
if t.beginPtraceStopLocked() {
- tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo)
+ tracer.signalStop(t, linux.CLD_TRAPPED, info.Signo)
tracer.tg.eventQueue.Notify(EventTraceeStop)
}
return true
@@ -829,7 +828,7 @@ func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions
if child.ptraceSeized {
child.trapStopPending = true
} else {
- child.pendingSignals.enqueue(&arch.SignalInfo{
+ child.pendingSignals.enqueue(&linux.SignalInfo{
Signo: int32(linux.SIGSTOP),
}, nil)
}
@@ -893,9 +892,9 @@ func (t *Task) ptraceExec(oldTID ThreadID) {
}
t.tg.signalHandlers.mu.Lock()
defer t.tg.signalHandlers.mu.Unlock()
- t.sendSignalLocked(&arch.SignalInfo{
+ t.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGTRAP),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, false /* group */)
}
@@ -1228,7 +1227,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
return err
case linux.PTRACE_SETSIGINFO:
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, data); err != nil {
return err
}
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index a95e174a2..54ca43c2e 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -39,11 +39,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
}
}
-func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo {
- si := &arch.SignalInfo{
+func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *linux.SignalInfo {
+ si := &linux.SignalInfo{
Signo: int32(linux.SIGSYS),
Errno: errno,
- Code: arch.SYS_SECCOMP,
+ Code: linux.SYS_SECCOMP,
}
si.SetCallAddr(uint64(ip))
si.SetSyscall(sysno)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index fe2ab1662..47bb66b42 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -35,10 +35,10 @@ const (
// Maximum number of semaphore sets.
setsMax = linux.SEMMNI
- // Maximum number of semaphroes in a semaphore set.
+ // Maximum number of semaphores in a semaphore set.
semsMax = linux.SEMMSL
- // Maximum number of semaphores in all semaphroe sets.
+ // Maximum number of semaphores in all semaphore sets.
semsTotalMax = linux.SEMMNS
)
@@ -171,10 +171,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
// Map semaphores and map indexes in a registry are of the same size,
// check map semaphores only here for the system limit.
if len(r.semaphores) >= setsMax {
- return nil, syserror.EINVAL
+ return nil, syserror.ENOSPC
}
if r.totalSems() > int(semsTotalMax-nsems) {
- return nil, syserror.EINVAL
+ return nil, syserror.ENOSPC
}
// Finally create a new set.
@@ -220,7 +220,7 @@ func (r *Registry) HighestIndex() int32 {
defer r.mu.Unlock()
// By default, highest used index is 0 even though
- // there is no semaphroe set.
+ // there is no semaphore set.
var highestIndex int32
for index := range r.indexes {
if index > highestIndex {
@@ -702,7 +702,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool {
return s.checkCapability(creds)
}
-// destroy destroys the set. Caller must hold 's.mu'.
+// destroy destroys the set.
+//
+// Preconditions: Caller must hold 's.mu'.
func (s *Set) destroy() {
// Notify all waiters. They will fail on the next attempt to execute
// operations and return error.
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 0cd9e2533..ca9076406 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -16,7 +16,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -233,7 +232,7 @@ func (pg *ProcessGroup) Session() *Session {
// SendSignal sends a signal to all processes inside the process group. It is
// analogous to kernel/signal.c:kill_pgrp.
-func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
+func (pg *ProcessGroup) SendSignal(info *linux.SignalInfo) error {
tasks := pg.originator.TaskSet()
tasks.mu.RLock()
defer tasks.mu.RUnlock()
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index 2488ae7d5..e08474d25 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -36,7 +35,7 @@ const SignalPanic = linux.SIGUSR2
// context is used only for debugging to differentiate these cases.
//
// Preconditions: Kernel must have an init process.
-func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
+func (k *Kernel) sendExternalSignal(info *linux.SignalInfo, context string) {
switch linux.Signal(info.Signo) {
case linux.SIGURG:
// Sent by the Go 1.14+ runtime for asynchronous goroutine preemption.
@@ -60,18 +59,18 @@ func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
}
// SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV.
-func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo {
- return &arch.SignalInfo{
+func SignalInfoPriv(sig linux.Signal) *linux.SignalInfo {
+ return &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
}
// SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO.
-func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
index 768fda220..147cc41bb 100644
--- a/pkg/sentry/kernel/signal_handlers.go
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -16,7 +16,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -30,14 +29,14 @@ type SignalHandlers struct {
mu sync.Mutex `state:"nosave"`
// actions is the action to be taken upon receiving each signal.
- actions map[linux.Signal]arch.SignalAct
+ actions map[linux.Signal]linux.SigAction
}
// NewSignalHandlers returns a new SignalHandlers specifying all default
// actions.
func NewSignalHandlers() *SignalHandlers {
return &SignalHandlers{
- actions: make(map[linux.Signal]arch.SignalAct),
+ actions: make(map[linux.Signal]linux.SigAction),
}
}
@@ -59,9 +58,9 @@ func (sh *SignalHandlers) CopyForExec() *SignalHandlers {
sh.mu.Lock()
defer sh.mu.Unlock()
for sig, act := range sh.actions {
- if act.Handler == arch.SignalActIgnore {
- sh2.actions[sig] = arch.SignalAct{
- Handler: arch.SignalActIgnore,
+ if act.Handler == linux.SIG_IGN {
+ sh2.actions[sig] = linux.SigAction{
+ Handler: linux.SIG_IGN,
}
}
}
@@ -73,15 +72,15 @@ func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool {
sh.mu.Lock()
defer sh.mu.Unlock()
sa, ok := sh.actions[sig]
- return ok && sa.Handler == arch.SignalActIgnore
+ return ok && sa.Handler == linux.SIG_IGN
}
// dequeueAction returns the SigAction that should be used to handle sig.
//
// Preconditions: sh.mu must be locked.
-func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct {
+func (sh *SignalHandlers) dequeueAction(sig linux.Signal) linux.SigAction {
act := sh.actions[sig]
- if act.IsResetHandler() {
+ if act.Flags&linux.SA_RESETHAND != 0 {
delete(sh.actions, sig)
}
return act
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index be1371855..2e3b4488a 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -151,7 +150,7 @@ type Task struct {
// which the SA_ONSTACK flag is set.
//
// signalStack is exclusive to the task goroutine.
- signalStack arch.SignalStack
+ signalStack linux.SignalStack
// signalQueue is a set of registered waiters for signal-related events.
//
@@ -395,7 +394,7 @@ type Task struct {
// ptraceSiginfo is analogous to Linux's task_struct::last_siginfo.
//
// ptraceSiginfo is protected by the TaskSet mutex.
- ptraceSiginfo *arch.SignalInfo
+ ptraceSiginfo *linux.SignalInfo
// ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to
// the tracer by ptrace(PTRACE_GETEVENTMSG).
@@ -853,15 +852,13 @@ func (t *Task) SetOOMScoreAdj(adj int32) error {
return nil
}
-// UID returns t's uid.
-// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
-func (t *Task) UID() uint32 {
+// KUID returns t's kuid.
+func (t *Task) KUID() uint32 {
return uint32(t.Credentials().EffectiveKUID)
}
-// GID returns t's gid.
-// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
-func (t *Task) GID() uint32 {
+// KGID returns t's kgid.
+func (t *Task) KGID() uint32 {
return uint32(t.Credentials().EffectiveKGID)
}
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
index 25d2504fa..7c138e80f 100644
--- a/pkg/sentry/kernel/task_cgroup.go
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -85,6 +85,14 @@ func (t *Task) enterCgroupLocked(c Cgroup) {
c.Enter(t)
}
+// +checklocks:t.mu
+func (t *Task) enterCgroupIfNotYetLocked(c Cgroup) {
+ if _, ok := t.cgroups[c]; ok {
+ return
+ }
+ t.enterCgroupLocked(c)
+}
+
// LeaveCgroups removes t out from all its cgroups.
func (t *Task) LeaveCgroups() {
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 70b0699dc..c82d9e82b 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -17,6 +17,7 @@ package kernel
import (
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} {
return t.k.RealtimeClock()
case limits.CtxLimits:
return t.tg.limits
+ case linux.CtxSignalNoInfoFunc:
+ return func(sig linux.Signal) error {
+ return t.SendSignal(SignalInfoNoInfo(sig, t, t))
+ }
case pgalloc.CtxMemoryFile:
return t.k.mf
case pgalloc.CtxMemoryFileProvider:
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index d9897e802..cf8571262 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -66,7 +66,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -181,7 +180,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec()
t.endStopCond.L = &t.tg.signalHandlers.mu
// "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2)
- t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable}
+ t.signalStack = linux.SignalStack{Flags: linux.SS_DISABLE}
// "The termination signal is reset to SIGCHLD (see clone(2))."
t.tg.terminationSignal = linux.SIGCHLD
// execed indicates that the process can no longer join a process group
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index b1af1a7ef..d115b8783 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -28,9 +28,9 @@ import (
"errors"
"fmt"
"strconv"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
@@ -50,6 +50,23 @@ type ExitStatus struct {
Signo int
}
+func (es ExitStatus) String() string {
+ var b strings.Builder
+ if code := es.Code; code != 0 {
+ if b.Len() != 0 {
+ b.WriteByte(' ')
+ }
+ _, _ = fmt.Fprintf(&b, "Code=%d", code)
+ }
+ if signal := es.Signo; signal != 0 {
+ if b.Len() != 0 {
+ b.WriteByte(' ')
+ }
+ _, _ = fmt.Fprintf(&b, "Signal=%d", signal)
+ }
+ return b.String()
+}
+
// Signaled returns true if the ExitStatus indicates that the exiting task or
// thread group was killed by a signal.
func (es ExitStatus) Signaled() bool {
@@ -122,12 +139,12 @@ func (t *Task) killLocked() {
if t.stop != nil && t.stop.Killable() {
t.endInternalStopLocked()
}
- t.pendingSignals.enqueue(&arch.SignalInfo{
+ t.pendingSignals.enqueue(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
// Linux just sets SIGKILL in the pending signal bitmask without
// enqueueing an actual siginfo, such that
// kernel/signal.c:collect_signal() initializes si_code to SI_USER.
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, nil)
t.interrupt()
}
@@ -332,7 +349,7 @@ func (t *Task) exitThreadGroup() bool {
// signalStop must be called with t's signal mutex unlocked.
t.tg.signalHandlers.mu.Unlock()
if notifyParent && t.tg.leader.parent != nil {
- t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
}
return last
@@ -353,7 +370,7 @@ func (t *Task) exitChildren() {
continue
}
other.signalHandlers.mu.Lock()
- other.leader.sendSignalLocked(&arch.SignalInfo{
+ other.leader.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
}, true /* group */)
other.signalHandlers.mu.Unlock()
@@ -368,9 +385,9 @@ func (t *Task) exitChildren() {
// wait for a parent to reap them.)
for c := range t.children {
if sig := c.ParentDeathSignal(); sig != 0 {
- siginfo := &arch.SignalInfo{
+ siginfo := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
siginfo.SetPID(int32(c.tg.pidns.tids[t]))
siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
@@ -652,10 +669,10 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
t.parent.tg.signalHandlers.mu.Lock()
if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach {
if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok {
- if act.Handler == arch.SignalActIgnore {
+ if act.Handler == linux.SIG_IGN {
t.exitParentAcked = true
signalParent = false
- } else if act.Flags&arch.SignalFlagNoCldWait != 0 {
+ } else if act.Flags&linux.SA_NOCLDWAIT != 0 {
t.exitParentAcked = true
}
}
@@ -705,17 +722,17 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
}
// Preconditions: The TaskSet mutex must be locked.
-func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
}
info.SetPID(int32(receiver.tg.pidns.tids[t]))
info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
- info.Code = arch.CLD_KILLED
+ info.Code = linux.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
} else {
- info.Code = arch.CLD_EXITED
+ info.Code = linux.CLD_EXITED
info.SetStatus(int32(t.exitStatus.Code))
}
// TODO(b/72102453): Set utime, stime.
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 9ba5f8d78..f142feab4 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -536,7 +536,7 @@ func (tg *ThreadGroup) updateCPUTimersEnabledLocked() {
// appropriate for /proc/[pid]/status.
func (t *Task) StateStatus() string {
switch s := t.TaskGoroutineSchedInfo().State; s {
- case TaskGoroutineNonexistent:
+ case TaskGoroutineNonexistent, TaskGoroutineRunningSys:
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
switch t.exitState {
@@ -546,16 +546,16 @@ func (t *Task) StateStatus() string {
return "X (dead)"
default:
// The task goroutine can't exit before passing through
- // runExitNotify, so this indicates that the task has been created,
- // but the task goroutine hasn't yet started. The Linux equivalent
- // is struct task_struct::state == TASK_NEW
+ // runExitNotify, so if s == TaskGoroutineNonexistent, the task has
+ // been created but the task goroutine hasn't yet started. The
+ // Linux equivalent is struct task_struct::state == TASK_NEW
// (kernel/fork.c:copy_process() =>
// kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is
// masked out by TASK_REPORT for /proc/[pid]/status, leaving only
// TASK_RUNNING.
return "R (running)"
}
- case TaskGoroutineRunningSys, TaskGoroutineRunningApp:
+ case TaskGoroutineRunningApp:
return "R (running)"
case TaskGoroutineBlockedInterruptible:
return "S (sleeping)"
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index c2b9fc08f..8ca61ed48 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -86,7 +86,7 @@ var defaultActions = map[linux.Signal]SignalAction{
}
// computeAction figures out what to do given a signal number
-// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop,
+// and a linux.SigAction. SIGSTOP always results in a SignalActionStop,
// and SIGKILL always results in a SignalActionTerm.
// Signal 0 is always ignored as many programs use it for various internal functions
// and don't expect it to do anything.
@@ -97,7 +97,7 @@ var defaultActions = map[linux.Signal]SignalAction{
// 0, the default action is taken;
// 1, the signal is ignored;
// anything else, the function returns SignalActionHandler.
-func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
+func computeAction(sig linux.Signal, act linux.SigAction) SignalAction {
switch sig {
case linux.SIGSTOP:
return SignalActionStop
@@ -108,9 +108,9 @@ func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
}
switch act.Handler {
- case arch.SignalActDefault:
+ case linux.SIG_DFL:
return defaultActions[sig]
- case arch.SignalActIgnore:
+ case linux.SIG_IGN:
return SignalActionIgnore
default:
return SignalActionHandler
@@ -127,7 +127,7 @@ var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTI
// If there are no pending unmasked signals, dequeueSignalLocked returns nil.
//
// Preconditions: t.tg.signalHandlers.mu must be locked.
-func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo {
+func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *linux.SignalInfo {
if info := t.pendingSignals.dequeue(mask); info != nil {
return info
}
@@ -155,7 +155,7 @@ func (t *Task) PendingSignals() linux.SignalSet {
}
// deliverSignal delivers the given signal and returns the following run state.
-func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState {
+func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRunState {
sigact := computeAction(linux.Signal(info.Signo), act)
if t.haveSyscallReturn {
@@ -172,7 +172,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
fallthrough
case sre == syserror.ERESTART_RESTARTBLOCK:
fallthrough
- case (sre == syserror.ERESTARTSYS && !act.IsRestart()):
+ case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
default:
@@ -236,7 +236,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
// deliverSignalToHandler changes the task's userspace state to enter the given
// user-configured handler for the given signal.
-func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error {
+func (t *Task) deliverSignalToHandler(info *linux.SignalInfo, act linux.SigAction) error {
// Signal delivery to an application handler interrupts restartable
// sequences.
t.rseqInterrupt()
@@ -248,8 +248,8 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// N.B. This is a *copy* of the alternate stack that the user's signal
// handler expects to see in its ucontext (even if it's not in use).
alt := t.signalStack
- if act.IsOnStack() && alt.IsEnabled() {
- alt.SetOnStack()
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() {
+ alt.Flags |= linux.SS_ONSTACK
if !alt.Contains(sp) {
sp = hostarch.Addr(alt.Top())
}
@@ -289,7 +289,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// Add our signal mask.
newMask := t.signalMask | act.Mask
- if !act.IsNoDefer() {
+ if act.Flags&linux.SA_NODEFER == 0 {
newMask |= linux.SignalSetOf(linux.Signal(info.Signo))
}
t.SetSignalMask(newMask)
@@ -326,7 +326,7 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) {
// Preconditions:
// * The caller must be running on the task goroutine.
// * t.exitState < TaskExitZombie.
-func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) {
+func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*linux.SignalInfo, error) {
// set is the set of signals we're interested in; invert it to get the set
// of signals to block.
mask := ^(set &^ UnblockableSignals)
@@ -373,7 +373,7 @@ func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.S
// syserror.EINVAL - The signal is not valid.
// syserror.EAGAIN - The signal is realtime, and cannot be queued.
//
-func (t *Task) SendSignal(info *arch.SignalInfo) error {
+func (t *Task) SendSignal(info *linux.SignalInfo) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -382,7 +382,7 @@ func (t *Task) SendSignal(info *arch.SignalInfo) error {
}
// SendGroupSignal sends the given signal to t's thread group.
-func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
+func (t *Task) SendGroupSignal(info *linux.SignalInfo) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -392,7 +392,7 @@ func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
// SendSignal sends the given signal to tg, using tg's leader to determine if
// the signal is blocked.
-func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
+func (tg *ThreadGroup) SendSignal(info *linux.SignalInfo) error {
tg.pidns.owner.mu.RLock()
defer tg.pidns.owner.mu.RUnlock()
tg.signalHandlers.mu.Lock()
@@ -400,11 +400,11 @@ func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
return tg.leader.sendSignalLocked(info, true /* group */)
}
-func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error {
+func (t *Task) sendSignalLocked(info *linux.SignalInfo, group bool) error {
return t.sendSignalTimerLocked(info, group, nil)
}
-func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error {
+func (t *Task) sendSignalTimerLocked(info *linux.SignalInfo, group bool, timer *IntervalTimer) error {
if t.exitState == TaskExitDead {
return syserror.ESRCH
}
@@ -572,9 +572,9 @@ func (t *Task) forceSignal(sig linux.Signal, unconditional bool) {
func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) {
blocked := linux.SignalSetOf(sig)&t.signalMask != 0
act := t.tg.signalHandlers.actions[sig]
- ignored := act.Handler == arch.SignalActIgnore
+ ignored := act.Handler == linux.SIG_IGN
if blocked || ignored || unconditional {
- act.Handler = arch.SignalActDefault
+ act.Handler = linux.SIG_DFL
t.tg.signalHandlers.actions[sig] = act
if blocked {
t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig))
@@ -641,17 +641,17 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) {
}
// SignalStack returns the task-private signal stack.
-func (t *Task) SignalStack() arch.SignalStack {
+func (t *Task) SignalStack() linux.SignalStack {
t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
alt := t.signalStack
if t.onSignalStack(alt) {
- alt.Flags |= arch.SignalStackFlagOnStack
+ alt.Flags |= linux.SS_ONSTACK
}
return alt
}
// onSignalStack returns true if the task is executing on the given signal stack.
-func (t *Task) onSignalStack(alt arch.SignalStack) bool {
+func (t *Task) onSignalStack(alt linux.SignalStack) bool {
sp := hostarch.Addr(t.Arch().Stack())
return alt.Contains(sp)
}
@@ -661,30 +661,30 @@ func (t *Task) onSignalStack(alt arch.SignalStack) bool {
// This value may not be changed if the task is currently executing on the
// signal stack, i.e. if t.onSignalStack returns true. In this case, this
// function will return false. Otherwise, true is returned.
-func (t *Task) SetSignalStack(alt arch.SignalStack) bool {
+func (t *Task) SetSignalStack(alt linux.SignalStack) bool {
// Check that we're not executing on the stack.
if t.onSignalStack(t.signalStack) {
return false
}
- if alt.Flags&arch.SignalStackFlagDisable != 0 {
+ if alt.Flags&linux.SS_DISABLE != 0 {
// Don't record anything beyond the flags.
- t.signalStack = arch.SignalStack{
- Flags: arch.SignalStackFlagDisable,
+ t.signalStack = linux.SignalStack{
+ Flags: linux.SS_DISABLE,
}
} else {
// Mask out irrelevant parts: only disable matters.
- alt.Flags &= arch.SignalStackFlagDisable
+ alt.Flags &= linux.SS_DISABLE
t.signalStack = alt
}
return true
}
-// SetSignalAct atomically sets the thread group's signal action for signal sig
+// SetSigAction atomically sets the thread group's signal action for signal sig
// to *actptr (if actptr is not nil) and returns the old signal action.
-func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) {
+func (tg *ThreadGroup) SetSigAction(sig linux.Signal, actptr *linux.SigAction) (linux.SigAction, error) {
if !sig.IsValid() {
- return arch.SignalAct{}, syserror.EINVAL
+ return linux.SigAction{}, syserror.EINVAL
}
tg.pidns.owner.mu.RLock()
@@ -718,48 +718,6 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
return oldact, nil
}
-// CopyOutSignalAct converts the given SignalAct into an architecture-specific
-// type and then copies it out to task memory.
-func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error {
- n := t.Arch().NewSignalAct()
- n.SerializeFrom(s)
- _, err := n.CopyOut(t, addr)
- return err
-}
-
-// CopyInSignalAct copies an architecture-specific sigaction type from task
-// memory and then converts it into a SignalAct.
-func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) {
- n := t.Arch().NewSignalAct()
- var s arch.SignalAct
- if _, err := n.CopyIn(t, addr); err != nil {
- return s, err
- }
- n.DeserializeTo(&s)
- return s, nil
-}
-
-// CopyOutSignalStack converts the given SignalStack into an
-// architecture-specific type and then copies it out to task memory.
-func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error {
- n := t.Arch().NewSignalStack()
- n.SerializeFrom(s)
- _, err := n.CopyOut(t, addr)
- return err
-}
-
-// CopyInSignalStack copies an architecture-specific stack_t from task memory
-// and then converts it into a SignalStack.
-func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) {
- n := t.Arch().NewSignalStack()
- var s arch.SignalStack
- if _, err := n.CopyIn(t, addr); err != nil {
- return s, err
- }
- n.DeserializeTo(&s)
- return s, nil
-}
-
// groupStop is a TaskStop placed on tasks that have received a stop signal
// (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from
// the ptrace man page.)
@@ -774,7 +732,7 @@ func (*groupStop) Killable() bool { return true }
// previously-dequeued stop signal.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) initiateGroupStop(info *arch.SignalInfo) {
+func (t *Task) initiateGroupStop(info *linux.SignalInfo) {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -909,8 +867,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
t.tg.signalHandlers.mu.Lock()
defer t.tg.signalHandlers.mu.Unlock()
act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD]
- if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) {
- sigchld := &arch.SignalInfo{
+ if !ok || (act.Handler != linux.SIG_IGN && act.Flags&linux.SA_NOCLDSTOP == 0) {
+ sigchld := &linux.SignalInfo{
Signo: int32(linux.SIGCHLD),
Code: code,
}
@@ -955,14 +913,14 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// notified its tracer accordingly. But it's consistent with
// Linux...
if intr {
- tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ tracer.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
if !notifyParent {
tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop)
} else {
tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop)
}
} else {
- tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ tracer.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig))
tracer.tg.eventQueue.Notify(EventGroupContinue)
}
}
@@ -974,10 +932,10 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// SIGCHLD is a standard signal, so the latter would always be
// dropped. Hence sending only the former is equivalent.
if intr {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop)
} else {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue)
}
}
@@ -1018,7 +976,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// without requiring an extra PTRACE_GETSIGINFO call." -
// "Group-stop", ptrace(2)
t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8
- t.ptraceSiginfo = &arch.SignalInfo{
+ t.ptraceSiginfo = &linux.SignalInfo{
Signo: int32(sig),
Code: t.ptraceCode,
}
@@ -1029,7 +987,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
t.ptraceSiginfo = nil
}
if t.beginPtraceStopLocked() {
- tracer.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ tracer.signalStop(t, linux.CLD_STOPPED, int32(sig))
// For consistency with Linux, if the parent and tracer are in the
// same thread group, deduplicate notification signals.
if notifyParent && tracer.tg == t.tg.leader.parent.tg {
@@ -1047,7 +1005,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
}
if notifyParent {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
}
t.tg.pidns.owner.mu.RUnlock()
@@ -1101,7 +1059,7 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if sig != linux.Signal(info.Signo) {
info.Signo = int32(sig)
info.Errno = 0
- info.Code = arch.SignalInfoUser
+ info.Code = linux.SI_USER
// pid isn't a valid field for all signal numbers, but Linux
// doesn't care (kernel/signal.c:ptrace_signal()).
//
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 32031cd70..41fd2d471 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
@@ -131,7 +130,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
runState: (*runApp)(nil),
interruptChan: make(chan struct{}, 1),
signalMask: cfg.SignalMask,
- signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ signalStack: linux.SignalStack{Flags: linux.SS_DISABLE},
image: *image,
fsContext: cfg.FSContext,
fdTable: cfg.FDTable,
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index b92e98fa1..891e2201d 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -446,10 +445,10 @@ func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
othertg.signalHandlers.mu.Lock()
othertg.tty = nil
if othertg.processGroup == tg.processGroup.session.foreground {
- if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
lastErr = err
}
- if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
lastErr = err
}
}
@@ -490,10 +489,10 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
tg.signalHandlers.mu.Lock()
defer tg.signalHandlers.mu.Unlock()
- // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
- // background process group in its session, and the calling process is
- // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
- // members of this background process group."
+ // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is not
+ // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of
+ // this background process group."
// tty must be the controlling terminal.
if tg.tty != tty {
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index f61a8e164..26aa34aa6 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -458,25 +458,6 @@ func NewTimer(clock Clock, listener TimerListener) *Timer {
return t
}
-// After waits for the duration to elapse according to clock and then sends a
-// notification on the returned channel. The timer is started immediately and
-// will fire exactly once. The second return value is the start time used with
-// the duration.
-//
-// Callers must call Timer.Destroy.
-func After(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) {
- notifier, tchan := NewChannelNotifier()
- t := NewTimer(clock, notifier)
- now := clock.Now()
-
- t.Swap(Setting{
- Enabled: true,
- Period: 0,
- Next: now.Add(duration),
- })
- return t, now, tchan
-}
-
// init initializes Timer state that is not preserved across save/restore. If
// init has already been called, calling it again is a no-op.
//
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index 3886b4d33..3e302d92c 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -59,7 +59,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil
// Linux silently truncates the remainder of the line if it exceeds
// interpMaxLineLength.
i := bytes.IndexByte(line, '\n')
- if i > 0 {
+ if i >= 0 {
line = line[:i]
}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index b81292c46..d1a883da4 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -1062,10 +1062,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- // If ManualZeroing is in effect, pages will be zeroed on allocation
- // and may not be freed by decommitFile, so calling decommitFile is
- // unnecessary.
- if !f.opts.ManualZeroing {
+ if f.opts.ManualZeroing {
+ // If ManualZeroing is in effect, only hugepage-aligned regions may
+ // be safely passed to decommitFile. Pages will be zeroed on
+ // reallocation, so we don't need to perform any manual zeroing
+ // here, whether or not decommitFile succeeds.
+ if startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp(); ok {
+ if endAddr := hostarch.Addr(fr.End).HugeRoundDown(); startAddr < endAddr {
+ decommitFR := memmap.FileRange{uint64(startAddr), uint64(endAddr)}
+ if err := f.decommitFile(decommitFR); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", decommitFR, err)
+ }
+ }
+ }
+ } else {
if err := f.decommitFile(fr); err != nil {
log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
// Zero the pages manually. This won't reduce memory usage, but at
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index b307832fd..8a490b3de 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -6,6 +6,8 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
+ "address_space_amd64.go",
+ "address_space_arm64.go",
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
@@ -77,6 +79,7 @@ go_test(
"requires-kvm",
],
deps = [
+ "//pkg/abi/linux",
"//pkg/hostarch",
"//pkg/ring0",
"//pkg/ring0/pagetables",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index 5524e8727..9929caebb 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -85,15 +85,6 @@ type addressSpace struct {
dirtySet *dirtySet
}
-// invalidate is the implementation for Invalidate.
-func (as *addressSpace) invalidate() {
- as.dirtySet.forEach(as.machine, func(c *vCPU) {
- if c.active.get() == as { // If this happens to be active,
- c.BounceToKernel() // ... force a kernel transition.
- }
- })
-}
-
// Invalidate interrupts all dirty contexts.
func (as *addressSpace) Invalidate() {
as.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/sentry/platform/kvm/address_space_amd64.go
index 2bf21a2e7..d11d38679 100644
--- a/pkg/tcpip/transport/tcp/rcv_state.go
+++ b/pkg/sentry/platform/kvm/address_space_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+package kvm
-import (
- "time"
-)
-
-// saveLastRcvdAckTime is invoked by stateify.
-func (r *receiver) saveLastRcvdAckTime() unixTime {
- return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
-}
-
-// loadLastRcvdAckTime is invoked by stateify.
-func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
- r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ as.dirtySet.forEach(as.machine, func(c *vCPU) {
+ if c.active.get() == as { // If this happens to be active,
+ c.BounceToKernel() // ... force a kernel transition.
+ }
+ })
}
diff --git a/pkg/tcpip/time.s b/pkg/sentry/platform/kvm/address_space_arm64.go
index fb37360ac..fb954418b 100644
--- a/pkg/tcpip/time.s
+++ b/pkg/sentry/platform/kvm/address_space_arm64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,4 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Empty assembly file so empty func definitions work.
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/ring0"
+)
+
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+ ring0.FlushTlbAll()
+}
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index f4d4473a8..183e741ea 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -17,6 +17,7 @@ package kvm
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
@@ -32,15 +33,15 @@ type context struct {
// machine is the parent machine, and is immutable.
machine *machine
- // info is the arch.SignalInfo cached for this context.
- info arch.SignalInfo
+ // info is the linux.SignalInfo cached for this context.
+ info linux.SignalInfo
// interrupt is the interrupt context.
interrupt interrupt.Forwarder
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*linux.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
localAS := as.(*addressSpace)
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index b8dd1e4a5..b1cab89a0 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -19,6 +19,7 @@ package kvm
import (
"testing"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -30,7 +31,7 @@ func TestSegments(t *testing.T) {
applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestSegments(regs)
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -55,7 +56,7 @@ func stmxcsr(addr *uint32)
func TestMXCSR(t *testing.T) {
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
switchOpts := ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index ceff09a60..fe570aff9 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -22,6 +22,7 @@ import (
"time"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
@@ -157,7 +158,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func
func TestApplicationSyscall(t *testing.T) {
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -171,7 +172,7 @@ func TestApplicationSyscall(t *testing.T) {
return false
})
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -188,7 +189,7 @@ func TestApplicationSyscall(t *testing.T) {
func TestApplicationFault(t *testing.T) {
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -203,7 +204,7 @@ func TestApplicationFault(t *testing.T) {
})
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -221,7 +222,7 @@ func TestRegistersSyscall(t *testing.T) {
applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -244,7 +245,7 @@ func TestRegistersFault(t *testing.T) {
applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -270,7 +271,7 @@ func TestBounce(t *testing.T) {
time.Sleep(time.Millisecond)
c.BounceToKernel()
}()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -285,7 +286,7 @@ func TestBounce(t *testing.T) {
time.Sleep(time.Millisecond)
c.BounceToKernel()
}()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -317,7 +318,7 @@ func TestBounceStress(t *testing.T) {
c.BounceToKernel()
}()
randomSleep()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -338,7 +339,7 @@ func TestInvalidate(t *testing.T) {
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, &data) // Read legitimate value.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -353,7 +354,7 @@ func TestInvalidate(t *testing.T) {
// Unmap the page containing data & invalidate.
pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize)
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -371,13 +372,13 @@ func TestInvalidate(t *testing.T) {
}
// IsFault returns true iff the given signal represents a fault.
-func IsFault(err error, si *arch.SignalInfo) bool {
+func IsFault(err error, si *linux.SignalInfo) bool {
return err == platform.ErrContextSignal && si.Signo == int32(unix.SIGSEGV)
}
func TestEmptyAddressSpace(t *testing.T) {
applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -391,7 +392,7 @@ func TestEmptyAddressSpace(t *testing.T) {
return false
})
applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -467,7 +468,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
a int // Count for ErrContextInterrupt.
)
applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -504,7 +505,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
a int
)
applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 9a2337654..7c063c7f5 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -23,11 +23,11 @@ import (
"runtime/debug"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
@@ -264,10 +264,10 @@ func (c *vCPU) setSystemTime() error {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
- *info = arch.SignalInfo{
+func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
+ *info = linux.SignalInfo{
Signo: signal,
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(addr) // Include address.
return hostarch.NoAccess, platform.ErrContextSignal
@@ -276,7 +276,7 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.Ac
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := ring0.ReadCR2()
code, user := c.ErrorCode()
@@ -287,7 +287,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType,
return hostarch.NoAccess, platform.ErrContextInterrupt
}
// Reset the pointed SignalInfo.
- *info = arch.SignalInfo{Signo: signal}
+ *info = linux.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
accessType := hostarch.AccessType{
Read: code&(1<<1) == 0,
@@ -325,7 +325,7 @@ func prefaultFloatingPointState(data *fpu.State) {
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) {
return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info)
@@ -371,7 +371,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return c.fault(int32(unix.SIGSEGV), info)
case ring0.Debug, ring0.Breakpoint:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGTRAP),
Code: 1, // TRAP_BRKPT (breakpoint).
}
@@ -383,9 +383,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
ring0.BoundRangeExceeded,
ring0.InvalidTSS,
ring0.StackSegmentFault:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGSEGV),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
if vector == ring0.GeneralProtectionFault {
@@ -397,7 +397,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.InvalidOpcode:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
}
@@ -405,7 +405,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.DivideByZero:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 1, // FPE_INTDIV (divide by zero).
}
@@ -413,7 +413,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.Overflow:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 2, // FPE_INTOVF (integer overflow).
}
@@ -422,7 +422,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.X87FloatingPointException,
ring0.SIMDFloatingPointException:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 7, // FPE_FLTINV (invalid operation).
}
@@ -433,7 +433,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.NoAccess, platform.ErrContextInterrupt
case ring0.AlignmentCheck:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
@@ -469,7 +469,7 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
}
func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
- // Map all the executible regions so that all the entry functions
+ // Map all the executable regions so that all the entry functions
// are mapped in the upper half.
applyVirtualRegions(func(vr virtualRegion) {
if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 8926b1d9f..edaccf9bc 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -21,10 +21,10 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -126,10 +126,10 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
- *info = arch.SignalInfo{
+func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
+ *info = linux.SignalInfo{
Signo: signal,
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(addr) // Include address.
return hostarch.NoAccess, platform.ErrContextSignal
@@ -157,7 +157,7 @@ func isWriteFault(code uint64) bool {
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := c.GetFaultAddr()
code, user := c.ErrorCode()
@@ -170,7 +170,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType,
}
// Reset the pointed SignalInfo.
- *info = arch.SignalInfo{Signo: signal}
+ *info = linux.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
ret := code & _ESR_ELx_FSC
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 92edc992b..1b0a6e0a7 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -23,10 +23,10 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
@@ -272,7 +272,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info)
@@ -319,14 +319,14 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.El0SyncUndef:
return c.fault(int32(unix.SIGILL), info)
case ring0.El0SyncDbg:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGTRAP),
Code: 1, // TRAP_BRKPT (breakpoint).
}
info.SetAddr(switchOpts.Registers.Pc) // Include address.
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.El0SyncSpPc:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index ef7814a6f..a26bc2316 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -195,8 +195,8 @@ type Context interface {
// - nil: The Context invoked a system call.
//
// - ErrContextSignal: The Context was interrupted by a signal. The
- // returned *arch.SignalInfo contains information about the signal. If
- // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType
+ // returned *linux.SignalInfo contains information about the signal. If
+ // linux.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType
// contains the access type of the triggering fault. The caller owns
// the returned SignalInfo.
//
@@ -207,7 +207,7 @@ type Context interface {
// concurrent call to Switch().
//
// - ErrContextCPUPreempted: See the definition of that error for details.
- Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error)
+ Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error)
// PullFullState() pulls a full state of the application thread.
//
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 828458ce2..319b0cf1d 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -73,7 +73,7 @@ var (
type context struct {
// signalInfo is the signal info, if and when a signal is received.
- signalInfo arch.SignalInfo
+ signalInfo linux.SignalInfo
// interrupt is the interrupt context.
interrupt interrupt.Forwarder
@@ -96,7 +96,7 @@ type context struct {
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
s := as.(*subprocess)
isSyscall := s.switchToApp(c, ac)
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index facb96011..cc93396a9 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -101,7 +101,7 @@ func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) erro
}
// getSignalInfo retrieves information about the signal that caused the stop.
-func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
+func (t *thread) getSignalInfo(si *linux.SignalInfo) error {
_, _, errno := unix.RawSyscall6(
unix.SYS_PTRACE,
unix.PTRACE_GETSIGINFO,
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9c73a725a..0931795c5 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -20,6 +20,7 @@ import (
"runtime"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -524,7 +525,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// Check for interrupts, and ensure that future interrupts will signal t.
if !c.interrupt.Enable(t) {
// Pending interrupt; simulate.
- c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)}
+ c.signalInfo = linux.SignalInfo{Signo: int32(platform.SignalInterrupt)}
return false
}
defer c.interrupt.Disable()
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 9252c0bd7..90b1ead56 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -155,7 +155,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index c0cbc0686..e4257e3bf 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index d6a2fbe34..3fe5c6770 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -21,25 +21,16 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
-// pkg/sentry/arch.
-type sigaction struct {
- handler uintptr
- flags uint64
- restorer uintptr
- mask uint64
-}
-
// IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not
// generate SIGCHLD when they stop.
func IgnoreChildStop() error {
- var sa sigaction
+ var sa linux.SigAction
// Get the existing signal handler information, and set the flag.
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 {
return e
}
- sa.flags |= linux.SA_NOCLDSTOP
+ sa.Flags |= linux.SA_NOCLDSTOP
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
return e
}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 2e3064565..3c6511ead 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -39,8 +39,6 @@ go_library(
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 393a1ab3a..cbb1e905d 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -35,8 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -66,8 +64,6 @@ type Stack struct {
tcpSACKEnabled bool
netDevFile *os.File
netSNMPFile *os.File
- ipv4Forwarding bool
- ipv6Forwarding bool
}
// NewStack returns an empty Stack containing no configuration.
@@ -127,13 +123,6 @@ func (s *Stack) Configure() error {
s.netSNMPFile = f
}
- s.ipv6Forwarding = false
- if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil {
- s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
- } else {
- log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false")
- }
-
return nil
}
@@ -492,19 +481,6 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
-// Forwarding implements inet.Stack.Forwarding.
-func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- switch protocol {
- case ipv4.ProtocolNumber:
- return s.ipv4Forwarding
- case ipv6.ProtocolNumber:
- return s.ipv6Forwarding
- default:
- log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol)
- return false
- }
-}
-
// SetForwarding implements inet.Stack.SetForwarding.
func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 61b2c9755..608474fa1 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -25,6 +25,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 6fc7781ad..3f1b4a17b 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,20 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// TODO(gvisor.dev/issue/170): The following per-matcher params should be
-// supported:
-// - Table name
-// - Match size
-// - User size
-// - Hooks
-// - Proto
-// - Family
-
// matchMaker knows how to (un)marshal the matcher named name().
type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
@@ -43,7 +35,7 @@ type matchMaker interface {
// unmarshal converts from the ABI matcher struct to a
// stack.Matcher.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
+ unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
type matcher interface {
@@ -94,12 +86,12 @@ func marshalEntryMatch(name string, data []byte) []byte {
return buf
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
+func unmarshalMatcher(task *kernel.Task, match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
}
- return matchMaker.unmarshal(buf, filter)
+ return matchMaker.unmarshal(task, buf, filter)
}
// targetMaker knows how to (un)marshal a target. Once registered,
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index cb78ef60b..d8bd86292 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -123,7 +124,7 @@ func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTG
return entries, info
}
-func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
nflog("set entries: setting entries in table %q", replace.Name.String())
// Convert input into a list of rules and their offsets.
@@ -148,23 +149,19 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): We should support more IPTIP
- // filtering fields.
filter, err := filterFromIPTIP(entry.IP)
if err != nil {
nflog("bad iptip: %v", err)
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Matchers and targets can specify
- // that they only work for certain protocols, hooks, tables.
// Get matchers.
matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
if len(optVal) < int(matchersSize) {
nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ matchers, err := parseMatchers(task, filter, optVal[:matchersSize])
if err != nil {
nflog("failed to parse matchers: %v", err)
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 5cb7fe4aa..c68230847 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -126,7 +127,7 @@ func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6T
return entries, info
}
-func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
nflog("set entries: setting entries in table %q", replace.Name.String())
// Convert input into a list of rules and their offsets.
@@ -151,23 +152,19 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): We should support more IPTIP
- // filtering fields.
filter, err := filterFromIP6TIP(entry.IPv6)
if err != nil {
nflog("bad iptip: %v", err)
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Matchers and targets can specify
- // that they only work for certain protocols, hooks, tables.
// Get matchers.
matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry
if len(optVal) < int(matchersSize) {
nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ matchers, err := parseMatchers(task, filter, optVal[:matchersSize])
if err != nil {
nflog("failed to parse matchers: %v", err)
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f42d73178..e3eade180 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{
// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
// compatibility with netfilter extensions.
-func DefaultLinuxTables() *stack.IPTables {
- tables := stack.DefaultTables()
+func DefaultLinuxTables(seed uint32) *stack.IPTables {
+ tables := stack.DefaultTables(seed)
tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
switch val := oldTarget.(type) {
case *stack.AcceptTarget:
@@ -174,13 +174,12 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
+func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
replace.UnmarshalBytes(replaceBuf)
- // TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
case filterTable:
@@ -188,16 +187,16 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
case natTable:
table = stack.EmptyNATTable()
default:
- nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ nflog("unknown iptables table %q", replace.Name.String())
return syserr.ErrInvalidArgument
}
var err *syserr.Error
var offsets map[uint32]int
if ipv6 {
- offsets, err = modifyEntries6(stk, optVal, &replace, &table)
+ offsets, err = modifyEntries6(task, stk, optVal, &replace, &table)
} else {
- offsets, err = modifyEntries4(stk, optVal, &replace, &table)
+ offsets, err = modifyEntries4(task, stk, optVal, &replace, &table)
}
if err != nil {
return err
@@ -272,7 +271,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
table.Rules[ruleIdx] = rule
}
- // TODO(gvisor.dev/issue/170): Support other chains.
// Since we don't support FORWARD yet, make sure all other chains point to
// ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
@@ -287,7 +285,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
}
- // TODO(gvisor.dev/issue/170): Check the following conditions:
+ // TODO(gvisor.dev/issue/6167): Check the following conditions:
// - There are no loops.
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
@@ -297,7 +295,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
+func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
var matchers []stack.Matcher
for len(optVal) > 0 {
@@ -321,13 +319,13 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
}
// Parse the specific matcher.
- matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ matcher, err := unmarshalMatcher(task, match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
if err != nil {
return nil, fmt.Errorf("failed to create matcher: %v", err)
}
matchers = append(matchers, matcher)
- // TODO(gvisor.dev/issue/170): Check the revision field.
+ // TODO(gvisor.dev/issue/6167): Check the revision field.
optVal = optVal[match.MatchSize:]
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 60845cab3..6eff2ae65 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,8 +42,8 @@ func (ownerMarshaler) name() string {
func (ownerMarshaler) marshal(mr matcher) []byte {
matcher := mr.(*OwnerMatcher)
iptOwnerInfo := linux.IPTOwnerInfo{
- UID: matcher.uid,
- GID: matcher.gid,
+ UID: uint32(matcher.uid),
+ GID: uint32(matcher.gid),
}
// Support for UID and GID match.
@@ -63,7 +65,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfIPTOwnerInfo {
return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
}
@@ -72,11 +74,12 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
- nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+ nflog("parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
- owner.uid = matchData.UID
- owner.gid = matchData.GID
+ creds := task.Credentials()
+ owner.uid = creds.UserNamespace.MapToKUID(auth.UID(matchData.UID))
+ owner.gid = creds.UserNamespace.MapToKGID(auth.GID(matchData.GID))
// Check flags.
if matchData.Match&linux.XT_OWNER_UID != 0 {
@@ -97,8 +100,8 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// OwnerMatcher matches against a UID and/or GID.
type OwnerMatcher struct {
- uid uint32
- gid uint32
+ uid auth.KUID
+ gid auth.KGID
matchUID bool
matchGID bool
invertUID bool
@@ -113,7 +116,6 @@ func (*OwnerMatcher) name() string {
// Match implements Matcher.Match.
func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// Only the OUTPUT chain is supported.
- // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
return false, true
}
@@ -126,7 +128,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str
var matches bool
// Check for UID match.
if om.matchUID {
- if pkt.Owner.UID() == om.uid {
+ if auth.KUID(pkt.Owner.KUID()) == om.uid {
matches = true
}
if matches == om.invertUID {
@@ -137,7 +139,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str
// Check for GID match.
if om.matchGID {
matches = false
- if pkt.Owner.GID() == om.gid {
+ if auth.KGID(pkt.Owner.KGID()) == om.gid {
matches = true
}
if matches == om.invertGID {
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index fa5456eee..7d83e708f 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -331,7 +331,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Check if the flags are valid.
// Also check if we need to map ports or IP.
// For now, redirect target only supports destination port change.
// Port range and IP range are not supported yet.
@@ -340,7 +339,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
return nil, syserr.ErrInvalidArgument
@@ -502,7 +500,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
return nil, syserr.ErrInvalidArgument
@@ -594,7 +591,6 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to a stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
- // TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
return &acceptTarget{stack.AcceptTarget{
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 95bb9826e..e5b73a976 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,7 +51,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -95,8 +96,6 @@ func (*TCPMatcher) name() string {
// Match implements Matcher.Match.
func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
- // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
netHeader := header.IPv4(pkt.NetworkHeader().View())
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index fb8be27e6..aa72ee70c 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,7 +51,7 @@ func (udpMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -92,8 +93,6 @@ func (*UDPMatcher) name() string {
// Match implements Matcher.Match.
func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
- // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
netHeader := header.IPv4(pkt.NetworkHeader().View())
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0b64a24c3..11ba80497 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -77,41 +77,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter {
// Metrics contains metrics exported by netstack.
var Metrics = tcpip.Stats{
- UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."),
- MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
- DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
+ DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."),
+ NICs: tcpip.NICStats{
+ UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."),
+ UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."),
+ MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."),
+ Tx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."),
+ Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."),
+ },
+ Rx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."),
+ Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."),
+ },
+ DisabledRx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."),
+ Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."),
+ },
+ Neighbor: tcpip.NICNeighborStats{
+ UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on a neighbor entry in Unreachable state."),
+ },
+ },
ICMP: tcpip.ICMPStats{
V4: tcpip.ICMPv4Stats{
PacketsSent: tcpip.ICMPv4SentPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."),
- RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."),
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."),
},
Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."),
},
@@ -119,40 +137,40 @@ var Metrics = tcpip.Stats{
V6: tcpip.ICMPv6Stats{
PacketsSent: tcpip.ICMPv6SentPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."),
- MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."),
- MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
- MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."),
- RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."),
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."),
- MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."),
- MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
- MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."),
},
Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."),
Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."),
@@ -163,23 +181,23 @@ var Metrics = tcpip.Stats{
IGMP: tcpip.IGMPStats{
PacketsSent: tcpip.IGMPSentPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."),
},
- Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."),
},
PacketsReceived: tcpip.IGMPReceivedPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."),
},
- Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."),
ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."),
- Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."),
},
},
IP: tcpip.IPStats{
@@ -205,7 +223,8 @@ var Metrics = tcpip.Stats{
LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."),
LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."),
ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."),
- PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."),
+ PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."),
+ HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."),
Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."),
},
},
@@ -1126,7 +1145,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
// TODO(b/64800844): Translate fields once they are added to
// tcpip.TCPInfoOption.
- info := linux.TCPInfo{}
+ info := linux.TCPInfo{
+ State: uint8(v.State),
+ RTO: uint32(v.RTO / time.Microsecond),
+ RTT: uint32(v.RTT / time.Microsecond),
+ RTTVar: uint32(v.RTTVar / time.Microsecond),
+ SndSsthresh: v.SndSsthresh,
+ SndCwnd: v.SndCwnd,
+ }
switch v.CcState {
case tcpip.RTORecovery:
info.CaState = linux.TCP_CA_Loss
@@ -1137,11 +1163,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
case tcpip.Open:
info.CaState = linux.TCP_CA_Open
}
- info.RTO = uint32(v.RTO / time.Microsecond)
- info.RTT = uint32(v.RTT / time.Microsecond)
- info.RTTVar = uint32(v.RTTVar / time.Microsecond)
- info.SndSsthresh = v.SndSsthresh
- info.SndCwnd = v.SndCwnd
// In netstack reorderSeen is updated only when RACK is enabled.
// We only track whether the reordering is seen, which is
@@ -1777,11 +1798,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := hostarch.ByteOrder.Uint32(optVal)
-
- if v == 0 {
- socket.SetSockOptEmitUnimplementedEvent(t, name)
- }
-
ep.SocketOptions().SetOutOfBandInline(v != 0)
return nil
@@ -2089,10 +2105,10 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
+ log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
return nil
default:
@@ -2332,10 +2348,10 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
+ log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
return nil
case linux.IP_ADD_SOURCE_MEMBERSHIP,
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 9cc1c57d7..eef5e6519 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -458,16 +458,6 @@ func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
s.Stack.RestoreCleanupEndpoints(es)
}
-// Forwarding implements inet.Stack.Forwarding.
-func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- switch protocol {
- case ipv4.ProtocolNumber, ipv6.ProtocolNumber:
- return s.Stack.Forwarding(protocol)
- default:
- panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol))
- }
-}
-
// SetForwarding implements inet.Stack.SetForwarding.
func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil {
diff --git a/pkg/sentry/socket/netstack/tun.go b/pkg/sentry/socket/netstack/tun.go
index 288dd0c9e..c7ed52702 100644
--- a/pkg/sentry/socket/netstack/tun.go
+++ b/pkg/sentry/socket/netstack/tun.go
@@ -40,7 +40,7 @@ func LinuxToTUNFlags(flags uint16) (tun.Flags, error) {
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately)
// when there is no sk_filter. See __tun_chr_ioctl() in
// drivers/net/tun.c.
- if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI) != 0 {
+ if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI|linux.IFF_ONE_QUEUE) != 0 {
return tun.Flags{}, syserror.EINVAL
}
return tun.Flags{
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 1fbbd133c..369541c7a 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -11,6 +11,7 @@ go_library(
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
+ "mmap.go",
"open.go",
"poll.go",
"ptrace.go",
@@ -35,7 +36,6 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/syscalls/linux",
- "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go
index ab1060426..bfb4d7f5c 100644
--- a/pkg/sentry/strace/clone.go
+++ b/pkg/sentry/strace/clone.go
@@ -15,98 +15,98 @@
package strace
import (
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// CloneFlagSet is the set of clone(2) flags.
var CloneFlagSet = abi.FlagSet{
{
- Flag: unix.CLONE_VM,
+ Flag: linux.CLONE_VM,
Name: "CLONE_VM",
},
{
- Flag: unix.CLONE_FS,
+ Flag: linux.CLONE_FS,
Name: "CLONE_FS",
},
{
- Flag: unix.CLONE_FILES,
+ Flag: linux.CLONE_FILES,
Name: "CLONE_FILES",
},
{
- Flag: unix.CLONE_SIGHAND,
+ Flag: linux.CLONE_SIGHAND,
Name: "CLONE_SIGHAND",
},
{
- Flag: unix.CLONE_PTRACE,
+ Flag: linux.CLONE_PTRACE,
Name: "CLONE_PTRACE",
},
{
- Flag: unix.CLONE_VFORK,
+ Flag: linux.CLONE_VFORK,
Name: "CLONE_VFORK",
},
{
- Flag: unix.CLONE_PARENT,
+ Flag: linux.CLONE_PARENT,
Name: "CLONE_PARENT",
},
{
- Flag: unix.CLONE_THREAD,
+ Flag: linux.CLONE_THREAD,
Name: "CLONE_THREAD",
},
{
- Flag: unix.CLONE_NEWNS,
+ Flag: linux.CLONE_NEWNS,
Name: "CLONE_NEWNS",
},
{
- Flag: unix.CLONE_SYSVSEM,
+ Flag: linux.CLONE_SYSVSEM,
Name: "CLONE_SYSVSEM",
},
{
- Flag: unix.CLONE_SETTLS,
+ Flag: linux.CLONE_SETTLS,
Name: "CLONE_SETTLS",
},
{
- Flag: unix.CLONE_PARENT_SETTID,
+ Flag: linux.CLONE_PARENT_SETTID,
Name: "CLONE_PARENT_SETTID",
},
{
- Flag: unix.CLONE_CHILD_CLEARTID,
+ Flag: linux.CLONE_CHILD_CLEARTID,
Name: "CLONE_CHILD_CLEARTID",
},
{
- Flag: unix.CLONE_DETACHED,
+ Flag: linux.CLONE_DETACHED,
Name: "CLONE_DETACHED",
},
{
- Flag: unix.CLONE_UNTRACED,
+ Flag: linux.CLONE_UNTRACED,
Name: "CLONE_UNTRACED",
},
{
- Flag: unix.CLONE_CHILD_SETTID,
+ Flag: linux.CLONE_CHILD_SETTID,
Name: "CLONE_CHILD_SETTID",
},
{
- Flag: unix.CLONE_NEWUTS,
+ Flag: linux.CLONE_NEWUTS,
Name: "CLONE_NEWUTS",
},
{
- Flag: unix.CLONE_NEWIPC,
+ Flag: linux.CLONE_NEWIPC,
Name: "CLONE_NEWIPC",
},
{
- Flag: unix.CLONE_NEWUSER,
+ Flag: linux.CLONE_NEWUSER,
Name: "CLONE_NEWUSER",
},
{
- Flag: unix.CLONE_NEWPID,
+ Flag: linux.CLONE_NEWPID,
Name: "CLONE_NEWPID",
},
{
- Flag: unix.CLONE_NEWNET,
+ Flag: linux.CLONE_NEWNET,
Name: "CLONE_NEWNET",
},
{
- Flag: unix.CLONE_IO,
+ Flag: linux.CLONE_IO,
Name: "CLONE_IO",
},
}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index d66befe81..6ce1bb592 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -33,7 +33,7 @@ var linuxAMD64 = SyscallMap{
6: makeSyscallInfo("lstat", Path, Stat),
7: makeSyscallInfo("poll", PollFDs, Hex, Hex),
8: makeSyscallInfo("lseek", Hex, Hex, Hex),
- 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 9: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex),
10: makeSyscallInfo("mprotect", Hex, Hex, Hex),
11: makeSyscallInfo("munmap", Hex, Hex),
12: makeSyscallInfo("brk", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 1a2d7d75f..ce5594301 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -246,7 +246,7 @@ var linuxARM64 = SyscallMap{
219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
- 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 222: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex),
223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
224: makeSyscallInfo("swapon", Hex, Hex),
225: makeSyscallInfo("swapoff", Hex),
diff --git a/pkg/sentry/strace/mmap.go b/pkg/sentry/strace/mmap.go
new file mode 100644
index 000000000..0035be586
--- /dev/null
+++ b/pkg/sentry/strace/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// ProtectionFlagSet represents the protection to mmap(2).
+var ProtectionFlagSet = abi.FlagSet{
+ {
+ Flag: linux.PROT_READ,
+ Name: "PROT_READ",
+ },
+ {
+ Flag: linux.PROT_WRITE,
+ Name: "PROT_WRITE",
+ },
+ {
+ Flag: linux.PROT_EXEC,
+ Name: "PROT_EXEC",
+ },
+}
+
+// MmapFlagSet is the set of mmap(2) flags.
+var MmapFlagSet = abi.FlagSet{
+ {
+ Flag: linux.MAP_SHARED,
+ Name: "MAP_SHARED",
+ },
+ {
+ Flag: linux.MAP_PRIVATE,
+ Name: "MAP_PRIVATE",
+ },
+ {
+ Flag: linux.MAP_FIXED,
+ Name: "MAP_FIXED",
+ },
+ {
+ Flag: linux.MAP_ANONYMOUS,
+ Name: "MAP_ANONYMOUS",
+ },
+ {
+ Flag: linux.MAP_GROWSDOWN,
+ Name: "MAP_GROWSDOWN",
+ },
+ {
+ Flag: linux.MAP_DENYWRITE,
+ Name: "MAP_DENYWRITE",
+ },
+ {
+ Flag: linux.MAP_EXECUTABLE,
+ Name: "MAP_EXECUTABLE",
+ },
+ {
+ Flag: linux.MAP_LOCKED,
+ Name: "MAP_LOCKED",
+ },
+ {
+ Flag: linux.MAP_NORESERVE,
+ Name: "MAP_NORESERVE",
+ },
+ {
+ Flag: linux.MAP_POPULATE,
+ Name: "MAP_POPULATE",
+ },
+ {
+ Flag: linux.MAP_NONBLOCK,
+ Name: "MAP_NONBLOCK",
+ },
+ {
+ Flag: linux.MAP_STACK,
+ Name: "MAP_STACK",
+ },
+ {
+ Flag: linux.MAP_HUGETLB,
+ Name: "MAP_HUGETLB",
+ },
+}
diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go
index 5769360da..e7c7649f4 100644
--- a/pkg/sentry/strace/open.go
+++ b/pkg/sentry/strace/open.go
@@ -15,61 +15,61 @@
package strace
import (
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// OpenMode represents the mode to open(2) a file.
var OpenMode = abi.ValueSet{
- unix.O_RDWR: "O_RDWR",
- unix.O_WRONLY: "O_WRONLY",
- unix.O_RDONLY: "O_RDONLY",
+ linux.O_RDWR: "O_RDWR",
+ linux.O_WRONLY: "O_WRONLY",
+ linux.O_RDONLY: "O_RDONLY",
}
// OpenFlagSet is the set of open(2) flags.
var OpenFlagSet = abi.FlagSet{
{
- Flag: unix.O_APPEND,
+ Flag: linux.O_APPEND,
Name: "O_APPEND",
},
{
- Flag: unix.O_ASYNC,
+ Flag: linux.O_ASYNC,
Name: "O_ASYNC",
},
{
- Flag: unix.O_CLOEXEC,
+ Flag: linux.O_CLOEXEC,
Name: "O_CLOEXEC",
},
{
- Flag: unix.O_CREAT,
+ Flag: linux.O_CREAT,
Name: "O_CREAT",
},
{
- Flag: unix.O_DIRECT,
+ Flag: linux.O_DIRECT,
Name: "O_DIRECT",
},
{
- Flag: unix.O_DIRECTORY,
+ Flag: linux.O_DIRECTORY,
Name: "O_DIRECTORY",
},
{
- Flag: unix.O_EXCL,
+ Flag: linux.O_EXCL,
Name: "O_EXCL",
},
{
- Flag: unix.O_NOATIME,
+ Flag: linux.O_NOATIME,
Name: "O_NOATIME",
},
{
- Flag: unix.O_NOCTTY,
+ Flag: linux.O_NOCTTY,
Name: "O_NOCTTY",
},
{
- Flag: unix.O_NOFOLLOW,
+ Flag: linux.O_NOFOLLOW,
Name: "O_NOFOLLOW",
},
{
- Flag: unix.O_NONBLOCK,
+ Flag: linux.O_NONBLOCK,
Name: "O_NONBLOCK",
},
{
@@ -77,18 +77,22 @@ var OpenFlagSet = abi.FlagSet{
Name: "O_PATH",
},
{
- Flag: unix.O_SYNC,
+ Flag: linux.O_SYNC,
Name: "O_SYNC",
},
{
- Flag: unix.O_TRUNC,
+ Flag: linux.O_TMPFILE,
+ Name: "O_TMPFILE",
+ },
+ {
+ Flag: linux.O_TRUNC,
Name: "O_TRUNC",
},
}
func open(val uint64) string {
- s := OpenMode.Parse(val & unix.O_ACCMODE)
- if flags := OpenFlagSet.Parse(val &^ unix.O_ACCMODE); flags != "" {
+ s := OpenMode.Parse(val & linux.O_ACCMODE)
+ if flags := OpenFlagSet.Parse(val &^ linux.O_ACCMODE); flags != "" {
s += "|" + flags
}
return s
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
index e5b379a20..5afc9525b 100644
--- a/pkg/sentry/strace/signal.go
+++ b/pkg/sentry/strace/signal.go
@@ -130,8 +130,8 @@ func sigAction(t *kernel.Task, addr hostarch.Addr) string {
return "null"
}
- sa, err := t.CopyInSignalAct(addr)
- if err != nil {
+ var sa linux.SigAction
+ if _, err := sa.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err)
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index ec5d5f846..af7088847 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -489,6 +489,10 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
+ case MmapProt:
+ output = append(output, ProtectionFlagSet.Parse(uint64(args[arg].Uint())))
+ case MmapFlags:
+ output = append(output, MmapFlagSet.Parse(uint64(args[arg].Uint())))
case Oct:
output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
case Hex:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 7e69b9279..5893443a7 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -238,6 +238,12 @@ const (
// EpollEvents is an array of struct epoll_event. It is the events
// argument in epoll_wait(2)/epoll_pwait(2).
EpollEvents
+
+ // MmapProt is the protection argument in mmap(2).
+ MmapProt
+
+ // MmapFlags is the flags argument in mmap(2).
+ MmapFlags
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 9cd238efd..90a719ba2 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1569,9 +1569,9 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -1632,9 +1632,9 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -1673,9 +1673,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
if err != nil {
return err
}
+
c := t.Credentials()
hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN)
isOwner := uattr.Owner.UID == c.EffectiveKUID
+ var clearPrivilege bool
if uid.Ok() {
kuid := c.UserNamespace.MapToKUID(uid)
// Valid UID must be supplied if UID is to be changed.
@@ -1693,6 +1695,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
return syserror.EPERM
}
+ // The setuid and setgid bits are cleared during a chown.
+ if uattr.Owner.UID != kuid {
+ clearPrivilege = true
+ }
+
owner.UID = kuid
}
if gid.Ok() {
@@ -1711,6 +1718,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
return syserror.EPERM
}
+ // The setuid and setgid bits are cleared during a chown.
+ if uattr.Owner.GID != kgid {
+ clearPrivilege = true
+ }
+
owner.GID = kgid
}
@@ -1721,10 +1733,14 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
if err := d.Inode.SetOwner(t, d, owner); err != nil {
return err
}
+ // Clear privilege bits if needed and they are set.
+ if clearPrivilege && uattr.Perms.HasSetUIDOrGID() && !fs.IsDir(d.Inode.StableAttr) {
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ if !d.Inode.SetPermissions(t, d, uattr.Perms) {
+ return syserror.EPERM
+ }
+ }
- // When the owner or group are changed by an unprivileged user,
- // chown(2) also clears the set-user-ID and set-group-ID bits, but
- // we do not support them.
return nil
}
@@ -2124,9 +2140,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EFBIG
}
if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 53b12dc41..27a7f7fe1 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -84,9 +84,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if !mayKill(t, target, sig) {
return 0, nil, syserror.EPERM
}
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
@@ -123,9 +123,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// depend on the iteration order. We at least implement the
// semantics documented by the man page: "On success (at least
// one signal was sent), zero is returned."
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
@@ -167,9 +167,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
continue
}
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
@@ -184,10 +184,10 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
}
}
-func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoTkill,
+ Code: linux.SI_TKILL,
}
info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
@@ -251,20 +251,20 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- var newactptr *arch.SignalAct
+ var newactptr *linux.SigAction
if newactarg != 0 {
- newact, err := t.CopyInSignalAct(newactarg)
- if err != nil {
+ var newact linux.SigAction
+ if _, err := newact.CopyIn(t, newactarg); err != nil {
return 0, nil, err
}
newactptr = &newact
}
- oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr)
+ oldact, err := t.ThreadGroup().SetSigAction(sig, newactptr)
if err != nil {
return 0, nil, err
}
if oldactarg != 0 {
- if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil {
+ if _, err := oldact.CopyOut(t, oldactarg); err != nil {
return 0, nil, err
}
}
@@ -325,13 +325,12 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
alt := t.SignalStack()
if oldaddr != 0 {
- if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil {
+ if _, err := alt.CopyOut(t, oldaddr); err != nil {
return 0, nil, err
}
}
if setaddr != 0 {
- alt, err := t.CopyInSignalStack(setaddr)
- if err != nil {
+ if _, err := alt.CopyIn(t, setaddr); err != nil {
return 0, nil, err
}
// The signal stack cannot be changed if the task is currently
@@ -410,7 +409,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// We must ensure that the Signo is set (Linux overrides this in the
// same way), and that the code is in the allowed set. This same logic
// appears below in RtTgsigqueueinfo and should be kept in sync.
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
@@ -426,7 +425,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// If the sender is not the receiver, it can't use si_codes used by the
// kernel or SI_TKILL.
- if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t {
return 0, nil, syserror.EPERM
}
@@ -454,7 +453,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
}
// Copy in the info. See RtSigqueueinfo above.
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
@@ -469,7 +468,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// If the sender is not the receiver, it can't use si_codes used by the
// kernel or SI_TKILL.
- if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t {
return 0, nil, syserror.EPERM
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 3185ea527..0d5056303 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -398,7 +398,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// out the fields it would set for a successful waitid in this case
// as well.
if infop != 0 {
- var si arch.SignalInfo
+ var si linux.SignalInfo
_, err = si.CopyOut(t, infop)
}
}
@@ -413,7 +413,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if infop == 0 {
return 0, nil, nil
}
- si := arch.SignalInfo{
+ si := linux.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
si.SetPID(int32(wr.TID))
@@ -423,24 +423,24 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
s := unix.WaitStatus(wr.Status)
switch {
case s.Exited():
- si.Code = arch.CLD_EXITED
+ si.Code = linux.CLD_EXITED
si.SetStatus(int32(s.ExitStatus()))
case s.Signaled():
- si.Code = arch.CLD_KILLED
+ si.Code = linux.CLD_KILLED
si.SetStatus(int32(s.Signal()))
case s.CoreDump():
- si.Code = arch.CLD_DUMPED
+ si.Code = linux.CLD_DUMPED
si.SetStatus(int32(s.Signal()))
case s.Stopped():
if wr.Event == kernel.EventTraceeStop {
- si.Code = arch.CLD_TRAPPED
+ si.Code = linux.CLD_TRAPPED
si.SetStatus(int32(s.TrapCause()))
} else {
- si.Code = arch.CLD_STOPPED
+ si.Code = linux.CLD_STOPPED
si.SetStatus(int32(s.StopSignal()))
}
case s.Continued():
- si.Code = arch.CLD_CONTINUED
+ si.Code = linux.CLD_CONTINUED
si.SetStatus(int32(linux.SIGCONT))
default:
t.Warningf("waitid got incomprehensible wait status %d", s)
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 83b777bbd..5c3b3dee2 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -180,21 +180,21 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
//
// +stateify savable
type clockNanosleepRestartBlock struct {
- c ktime.Clock
- duration time.Duration
- rem hostarch.Addr
+ c ktime.Clock
+ end ktime.Time
+ rem hostarch.Addr
}
// Restart implements kernel.SyscallRestartBlock.Restart.
func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
- return 0, clockNanosleepFor(t, n.c, n.duration, n.rem)
+ return 0, clockNanosleepUntil(t, n.c, n.end, n.rem, true)
}
// clockNanosleepUntil blocks until a specified time.
//
// If blocking is interrupted, the syscall is restarted: via a restart block
// (sleeping until end) if needRestartBlock, else with the original arguments.
-func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error {
+func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem hostarch.Addr, needRestartBlock bool) error {
notifier, tchan := ktime.NewChannelNotifier()
timer := ktime.NewTimer(c, notifier)
@@ -202,43 +202,22 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error
timer.Swap(ktime.Setting{
Period: 0,
Enabled: true,
- Next: ktime.FromTimespec(ts),
+ Next: end,
})
err := t.BlockWithTimer(nil, tchan)
timer.Destroy()
- // Did we just block until the timeout happened?
- if err == syserror.ETIMEDOUT {
- return nil
- }
-
- return syserror.ConvertIntr(err, syserror.ERESTARTNOHAND)
-}
-
-// clockNanosleepFor blocks for a specified duration.
-//
-// If blocking is interrupted, the syscall is restarted with the remaining
-// duration timeout.
-func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hostarch.Addr) error {
- timer, start, tchan := ktime.After(c, dur)
-
- err := t.BlockWithTimer(nil, tchan)
-
- after := c.Now()
-
- timer.Destroy()
-
switch err {
case syserror.ETIMEDOUT:
// Slept for entire timeout.
return nil
case syserror.ErrInterrupted:
// Interrupted.
- remaining := dur - after.Sub(start)
- if remaining < 0 {
- remaining = time.Duration(0)
+ remaining := end.Sub(c.Now())
+ if remaining <= 0 {
+ return nil
}
// Copy out remaining time.
@@ -248,14 +227,16 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hos
return err
}
}
-
- // Arrange for a restart with the remaining duration.
- t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
- c: c,
- duration: remaining,
- rem: rem,
- })
- return syserror.ERESTART_RESTARTBLOCK
+ if needRestartBlock {
+ // Arrange for a restart that resumes sleeping until the original end time.
+ t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
+ c: c,
+ end: end,
+ rem: rem,
+ })
+ return syserror.ERESTART_RESTARTBLOCK
+ }
+ return syserror.ERESTARTNOHAND
default:
panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
@@ -278,7 +259,8 @@ func Nanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Just like linux, we cap the timeout with the max number that int64 can
// represent which is roughly 292 years.
dur := time.Duration(ts.ToNsecCapped()) * time.Nanosecond
- return 0, nil, clockNanosleepFor(t, t.Kernel().MonotonicClock(), dur, rem)
+ c := t.Kernel().MonotonicClock()
+ return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true)
}
// ClockNanosleep implements linux syscall clock_nanosleep(2).
@@ -312,11 +294,11 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
if flags&linux.TIMER_ABSTIME != 0 {
- return 0, nil, clockNanosleepUntil(t, c, req)
+ return 0, nil, clockNanosleepUntil(t, c, ktime.FromTimespec(req), 0, false)
}
dur := time.Duration(req.ToNsecCapped()) * time.Nanosecond
- return 0, nil, clockNanosleepFor(t, c, dur, rem)
+ return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true)
}
// Gettimeofday implements linux syscall gettimeofday(2).
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index c6330c21a..647e089d0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -242,9 +242,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
limit := limits.FromContext(t).Get(limits.FileSize).Cur
if uint64(size) >= limit {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index bc75a47fc..202486a1e 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//pkg/sync:generic_seqatomic",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
types = {
"Value": "Parameters",
},
diff --git a/pkg/sentry/time/sampler_amd64.go b/pkg/sentry/time/sampler_amd64.go
index 72cb74be3..9f1b4b2fb 100644
--- a/pkg/sentry/time/sampler_amd64.go
+++ b/pkg/sentry/time/sampler_amd64.go
@@ -16,7 +16,7 @@
package time
-const(
+const (
// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
// It is further refined as syscalls are made.
defaultOverheadCycles = 1 * 1000
diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go
index b9d0273b7..4c8d33ae4 100644
--- a/pkg/sentry/time/sampler_arm64.go
+++ b/pkg/sentry/time/sampler_arm64.go
@@ -30,7 +30,7 @@ func getDefaultArchOverheadCycles() TSCValue {
// x86 devided by frqRatio
cntfrq := getCNTFRQ()
frqRatio := 1000000000 / cntfrq
- overheadCycles := ( 1 * 1000 ) / frqRatio
+ overheadCycles := (1 * 1000) / frqRatio
return overheadCycles
}
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 176bcc242..ef8d8a813 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -454,6 +454,9 @@ type FileDescriptionImpl interface {
// RemoveXattr removes the given extended attribute from the file.
RemoveXattr(ctx context.Context, name string) error
+ // SupportsLocks indicates whether file locks are supported.
+ SupportsLocks() bool
+
// LockBSD tries to acquire a BSD-style advisory file lock.
LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error
@@ -818,6 +821,11 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e
return fd.Sync(ctx)
}
+// SupportsLocks indicates whether file locks are supported.
+func (fd *FileDescription) SupportsLocks() bool {
+ return fd.impl.SupportsLocks()
+}
+
// LockBSD tries to acquire a BSD-style advisory file lock.
func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error {
atomic.StoreUint32(&fd.usedLockBSD, 1)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index b87d9690a..2b6f47b4b 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -413,6 +413,11 @@ type LockFD struct {
locks *FileLocks
}
+// SupportsLocks implements FileDescriptionImpl.SupportsLocks.
+func (LockFD) SupportsLocks() bool {
+ return true
+}
+
// Init initializes fd with FileLocks to use.
func (fd *LockFD) Init(locks *FileLocks) {
fd.locks = locks
@@ -423,28 +428,28 @@ func (fd *LockFD) Locks() *FileLocks {
return fd.locks
}
-// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+// LockBSD implements FileDescriptionImpl.LockBSD.
func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return fd.locks.LockBSD(ctx, uid, ownerPID, t, block)
}
-// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
fd.locks.UnlockBSD(uid)
return nil
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block)
}
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
return fd.locks.UnlockPOSIX(ctx, uid, r)
}
-// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX.
+// TestPOSIX implements FileDescriptionImpl.TestPOSIX.
func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
return fd.locks.TestPOSIX(ctx, uid, t, r)
}
@@ -455,27 +460,68 @@ func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.L
// +stateify savable
type NoLockFD struct{}
-// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+// SupportsLocks implements FileDescriptionImpl.SupportsLocks.
+func (NoLockFD) SupportsLocks() bool {
+ return false
+}
+
+// LockBSD implements FileDescriptionImpl.LockBSD.
func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return syserror.ENOLCK
}
-// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
return syserror.ENOLCK
}
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
return syserror.ENOLCK
}
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
return syserror.ENOLCK
}
-// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX.
+// TestPOSIX implements FileDescriptionImpl.TestPOSIX.
func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
return linux.Flock{}, syserror.ENOLCK
}
+
+// BadLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
+// returning EBADF.
+//
+// +stateify savable
+type BadLockFD struct{}
+
+// SupportsLocks implements FileDescriptionImpl.SupportsLocks.
+func (BadLockFD) SupportsLocks() bool {
+ return false
+}
+
+// LockBSD implements FileDescriptionImpl.LockBSD.
+func (BadLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
+ return syserror.EBADF
+}
+
+// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
+func (BadLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return syserror.EBADF
+}
+
+// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
+func (BadLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error {
+ return syserror.EBADF
+}
+
+// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
+func (BadLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error {
+ return syserror.EBADF
+}
+
+// TestPOSIX implements FileDescriptionImpl.TestPOSIX.
+func (BadLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) {
+ return linux.Flock{}, syserror.EBADF
+}
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
index d8c4d27b9..ea82f4987 100644
--- a/pkg/sentry/vfs/memxattr/BUILD
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
index 638b5d830..9b7953fa3 100644
--- a/pkg/sentry/vfs/memxattr/xattr.go
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -17,7 +17,10 @@
package memxattr
import (
+ "strings"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -26,6 +29,9 @@ import (
// SimpleExtendedAttributes implements extended attributes using a map of
// names to values.
//
+// SimpleExtendedAttributes calls vfs.CheckXattrPermissions, so callers are not
+// required to do so.
+//
// +stateify savable
type SimpleExtendedAttributes struct {
// mu protects the below fields.
@@ -34,7 +40,11 @@ type SimpleExtendedAttributes struct {
}
// GetXattr returns the value at 'name'.
-func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) {
+func (x *SimpleExtendedAttributes) GetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.GetXattrOptions) (string, error) {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayRead, mode, kuid, opts.Name); err != nil {
+ return "", err
+ }
+
x.mu.RLock()
value, ok := x.xattrs[opts.Name]
x.mu.RUnlock()
@@ -50,7 +60,11 @@ func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string,
}
// SetXattr sets 'value' at 'name'.
-func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error {
+func (x *SimpleExtendedAttributes) SetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.SetXattrOptions) error {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, opts.Name); err != nil {
+ return err
+ }
+
x.mu.Lock()
defer x.mu.Unlock()
if x.xattrs == nil {
@@ -73,12 +87,19 @@ func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error {
}
// ListXattr returns all names in xattrs.
-func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) {
+func (x *SimpleExtendedAttributes) ListXattr(creds *auth.Credentials, size uint64) ([]string, error) {
// Keep track of the size of the buffer needed in listxattr(2) for the list.
listSize := 0
x.mu.RLock()
names := make([]string, 0, len(x.xattrs))
+ haveCap := creds.HasCapability(linux.CAP_SYS_ADMIN)
for n := range x.xattrs {
+ // Hide extended attributes in the "trusted" namespace from
+ // non-privileged users. This is consistent with Linux's
+ // fs/xattr.c:simple_xattr_list().
+ if !haveCap && strings.HasPrefix(n, linux.XATTR_TRUSTED_PREFIX) {
+ continue
+ }
names = append(names, n)
// Add one byte per null terminator.
listSize += len(n) + 1
@@ -91,7 +112,11 @@ func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) {
}
// RemoveXattr removes the xattr at 'name'.
-func (x *SimpleExtendedAttributes) RemoveXattr(name string) error {
+func (x *SimpleExtendedAttributes) RemoveXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, name string) error {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, name); err != nil {
+ return err
+ }
+
x.mu.Lock()
defer x.mu.Unlock()
if _, ok := x.xattrs[name]; !ok {
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 82fd382c2..f93da3af1 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -220,7 +220,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
vdDentry := vd.dentry
vdDentry.mu.Lock()
for {
- if vdDentry.dead {
+ if vd.mount.umounted || vdDentry.dead {
vdDentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef(ctx)
diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go
index 47848c76b..e9651b631 100644
--- a/pkg/sentry/vfs/opath.go
+++ b/pkg/sentry/vfs/opath.go
@@ -24,96 +24,96 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH.
+// opathFD implements FileDescriptionImpl for a file description opened with O_PATH.
//
// +stateify savable
type opathFD struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
- NoLockFD
+ BadLockFD
}
-// Release implements vfs.FileDescriptionImpl.Release.
+// Release implements FileDescriptionImpl.Release.
func (fd *opathFD) Release(context.Context) {
// noop
}
-// Allocate implements vfs.FileDescriptionImpl.Allocate.
+// Allocate implements FileDescriptionImpl.Allocate.
func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
return syserror.EBADF
}
-// PRead implements vfs.FileDescriptionImpl.PRead.
+// PRead implements FileDescriptionImpl.PRead.
func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
return 0, syserror.EBADF
}
-// Read implements vfs.FileDescriptionImpl.Read.
+// Read implements FileDescriptionImpl.Read.
func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
return 0, syserror.EBADF
}
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
+// PWrite implements FileDescriptionImpl.PWrite.
func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
return 0, syserror.EBADF
}
-// Write implements vfs.FileDescriptionImpl.Write.
+// Write implements FileDescriptionImpl.Write.
func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
return 0, syserror.EBADF
}
-// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+// Ioctl implements FileDescriptionImpl.Ioctl.
func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return 0, syserror.EBADF
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+// IterDirents implements FileDescriptionImpl.IterDirents.
func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
return syserror.EBADF
}
-// Seek implements vfs.FileDescriptionImpl.Seek.
+// Seek implements FileDescriptionImpl.Seek.
func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
return 0, syserror.EBADF
}
-// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap.
func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.EBADF
}
-// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
+// ListXattr implements FileDescriptionImpl.ListXattr.
func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) {
return nil, syserror.EBADF
}
-// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
+// GetXattr implements FileDescriptionImpl.GetXattr.
func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) {
return "", syserror.EBADF
}
-// SetXattr implements vfs.FileDescriptionImpl.SetXattr.
+// SetXattr implements FileDescriptionImpl.SetXattr.
func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error {
return syserror.EBADF
}
-// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr.
+// RemoveXattr implements FileDescriptionImpl.RemoveXattr.
func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error {
return syserror.EBADF
}
-// Sync implements vfs.FileDescriptionImpl.Sync.
+// Sync implements FileDescriptionImpl.Sync.
func (fd *opathFD) Sync(ctx context.Context) error {
return syserror.EBADF
}
-// SetStat implements vfs.FileDescriptionImpl.SetStat.
+// SetStat implements FileDescriptionImpl.SetStat.
func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error {
return syserror.EBADF
}
-// Stat implements vfs.FileDescriptionImpl.Stat.
+// Stat implements FileDescriptionImpl.Stat.
func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
vfsObj := fd.vfsfd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 8e3146d8d..8d563d53a 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -115,14 +115,14 @@ func (a *Action) Get() interface{} {
}
// String returns Action's string representation.
-func (a *Action) String() string {
- switch *a {
+func (a Action) String() string {
+ switch a {
case LogWarning:
return "logWarning"
case Panic:
return "panic"
default:
- panic(fmt.Sprintf("Invalid watchdog action: %d", *a))
+ panic(fmt.Sprintf("Invalid watchdog action: %d", a))
}
}
@@ -243,6 +243,7 @@ func (w *Watchdog) waitForStart() {
}
stuckStartup.Increment()
+ metric.WeirdnessMetric.Increment("watchdog_stuck_startup")
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
@@ -312,10 +313,11 @@ func (w *Watchdog) runTurn() {
// New stuck task detected.
//
// Note that tasks blocked doing IO may be considered stuck in kernel,
- // unless they are surrounded b
+ // unless they are surrounded by
// Task.UninterruptibleSleepStart/Finish.
tc = &offender{lastUpdateTime: lastUpdateTime}
stuckTasks.Increment()
+ metric.WeirdnessMetric.Increment("watchdog_stuck_tasks")
newTaskFound = true
}
newOffenders[t] = tc
diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD
index 4f7c02f5d..fd6127b97 100644
--- a/pkg/shim/BUILD
+++ b/pkg/shim/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -41,7 +41,19 @@ go_library(
"@com_github_containerd_fifo//:go_default_library",
"@com_github_containerd_typeurl//:go_default_library",
"@com_github_gogo_protobuf//types:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "shim_test",
+ size = "small",
+ srcs = ["service_test.go"],
+ library = ":shim",
+ deps = [
+ "//pkg/shim/utils",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/shim/service.go b/pkg/shim/service.go
index 9d9fa8ef6..1f9adcb65 100644
--- a/pkg/shim/service.go
+++ b/pkg/shim/service.go
@@ -22,6 +22,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "strings"
"sync"
"time"
@@ -44,6 +45,7 @@ import (
"github.com/containerd/containerd/sys/reaper"
"github.com/containerd/typeurl"
"github.com/gogo/protobuf/types"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cleanup"
@@ -944,9 +946,19 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
if err != nil {
return nil, fmt.Errorf("read oci spec: %w", err)
}
- if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+
+ updated, err := utils.UpdateVolumeAnnotations(spec)
+ if err != nil {
return nil, fmt.Errorf("update volume annotations: %w", err)
}
+ updated = updateCgroup(spec) || updated
+
+ if updated {
+ if err := utils.WriteSpec(r.Bundle, spec); err != nil {
+ return nil, err
+ }
+ }
+
runsc.FormatRunscLogPath(r.ID, options.RunscConfig)
runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig)
p := proc.New(r.ID, runtime, stdio.Stdio{
@@ -966,3 +978,39 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
p.Monitor = reaper.Default
return p, nil
}
+
+// updateCgroup updates cgroup path for the sandbox to make the sandbox join the
+// pod cgroup and not the pause container cgroup. Returns true if the spec was
+// modified. Ex.:
+// /kubepods/burstable/pod123/abc => /kubepods/burstable/pod123
+//
+func updateCgroup(spec *specs.Spec) bool {
+ if !utils.IsSandbox(spec) {
+ return false
+ }
+ if spec.Linux == nil || len(spec.Linux.CgroupsPath) == 0 {
+ return false
+ }
+
+ // Search backwards for the pod cgroup path to make the sandbox use it,
+ // instead of the pause container's cgroup.
+ parts := strings.Split(spec.Linux.CgroupsPath, string(filepath.Separator))
+ for i := len(parts) - 1; i >= 0; i-- {
+ if strings.HasPrefix(parts[i], "pod") {
+ var path string
+ for j := 0; j <= i; j++ {
+ path = filepath.Join(path, parts[j])
+ }
+ // Add back the initial '/' that may have been lost above.
+ if filepath.IsAbs(spec.Linux.CgroupsPath) {
+ path = string(filepath.Separator) + path
+ }
+ if spec.Linux.CgroupsPath == path {
+ return false
+ }
+ spec.Linux.CgroupsPath = path
+ return true
+ }
+ }
+ return false
+}
diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go
new file mode 100644
index 000000000..2d9f07e02
--- /dev/null
+++ b/pkg/shim/service_test.go
@@ -0,0 +1,121 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/shim/utils"
+)
+
+func TestCgroupPath(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ path string
+ want string
+ }{
+ {
+ name: "simple",
+ path: "foo/pod123/container",
+ want: "foo/pod123",
+ },
+ {
+ name: "absolute",
+ path: "/foo/pod123/container",
+ want: "/foo/pod123",
+ },
+ {
+ name: "no-container",
+ path: "foo/pod123",
+ want: "foo/pod123",
+ },
+ {
+ name: "no-container-absolute",
+ path: "/foo/pod123",
+ want: "/foo/pod123",
+ },
+ {
+ name: "double-pod",
+ path: "/foo/podium/pod123/container",
+ want: "/foo/podium/pod123",
+ },
+ {
+ name: "start-pod",
+ path: "pod123/container",
+ want: "pod123",
+ },
+ {
+ name: "start-pod-absolute",
+ path: "/pod123/container",
+ want: "/pod123",
+ },
+ {
+ name: "slashes",
+ path: "///foo/////pod123//////container",
+ want: "/foo/pod123",
+ },
+ {
+ name: "no-pod",
+ path: "/foo/nopod123/container",
+ want: "/foo/nopod123/container",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ spec := specs.Spec{
+ Linux: &specs.Linux{
+ CgroupsPath: tc.path,
+ },
+ }
+ updated := updateCgroup(&spec)
+ if spec.Linux.CgroupsPath != tc.want {
+ t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath)
+ }
+ if shouldUpdate := tc.path != tc.want; shouldUpdate != updated {
+ t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate)
+ }
+ })
+ }
+}
+
+// TestCgroupNoUpdate covers cases where the cgroup path should not be updated.
+func TestCgroupNoUpdate(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.Spec
+ }{
+ {
+ name: "empty",
+ spec: &specs.Spec{},
+ },
+ {
+ name: "subcontainer",
+ spec: &specs.Spec{
+ Linux: &specs.Linux{
+ CgroupsPath: "foo/pod123/container",
+ },
+ Annotations: map[string]string{
+ utils.ContainerTypeAnnotation: utils.ContainerTypeContainer,
+ },
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ if updated := updateCgroup(tc.spec); updated {
+ t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated)
+ }
+ })
+ }
+}
diff --git a/pkg/shim/utils/annotations.go b/pkg/shim/utils/annotations.go
index 1e9d3f365..c744800bb 100644
--- a/pkg/shim/utils/annotations.go
+++ b/pkg/shim/utils/annotations.go
@@ -19,7 +19,9 @@ package utils
// These are vendor due to import conflicts.
const (
sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
- containerTypeAnnotation = "io.kubernetes.cri.container-type"
+ // ContainerTypeAnnotation is the key that defines sandbox or container.
+ ContainerTypeAnnotation = "io.kubernetes.cri.container-type"
containerTypeSandbox = "sandbox"
- containerTypeContainer = "container"
+ // ContainerTypeContainer is the value for container.
+ ContainerTypeContainer = "container"
)
diff --git a/pkg/shim/utils/utils.go b/pkg/shim/utils/utils.go
index 7b1cd983e..f183b1bbc 100644
--- a/pkg/shim/utils/utils.go
+++ b/pkg/shim/utils/utils.go
@@ -18,19 +18,16 @@ package utils
import (
"encoding/json"
"io/ioutil"
- "os"
"path/filepath"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
+const configFilename = "config.json"
+
// ReadSpec reads OCI spec from the bundle directory.
func ReadSpec(bundle string) (*specs.Spec, error) {
- f, err := os.Open(filepath.Join(bundle, "config.json"))
- if err != nil {
- return nil, err
- }
- b, err := ioutil.ReadAll(f)
+ b, err := ioutil.ReadFile(filepath.Join(bundle, configFilename))
if err != nil {
return nil, err
}
@@ -41,9 +38,18 @@ func ReadSpec(bundle string) (*specs.Spec, error) {
return &spec, nil
}
+// WriteSpec writes OCI spec to the bundle directory.
+func WriteSpec(bundle string, spec *specs.Spec) error {
+ b, err := json.Marshal(spec)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(bundle, configFilename), b, 0666)
+}
+
// IsSandbox checks whether a container is a sandbox container.
func IsSandbox(spec *specs.Spec) bool {
- t, ok := spec.Annotations[containerTypeAnnotation]
+ t, ok := spec.Annotations[ContainerTypeAnnotation]
return !ok || t == containerTypeSandbox
}
diff --git a/pkg/shim/utils/volumes.go b/pkg/shim/utils/volumes.go
index cdcb88229..6bc75139d 100644
--- a/pkg/shim/utils/volumes.go
+++ b/pkg/shim/utils/volumes.go
@@ -15,9 +15,7 @@
package utils
import (
- "encoding/json"
"fmt"
- "io/ioutil"
"path/filepath"
"strings"
@@ -89,8 +87,8 @@ func isVolumePath(volume, path string) (bool, error) {
}
// UpdateVolumeAnnotations add necessary OCI annotations for gvisor
-// volume optimization.
-func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
+// volume optimization. Returns true if the spec was modified.
+func UpdateVolumeAnnotations(s *specs.Spec) (bool, error) {
var uid string
if IsSandbox(s) {
var err error
@@ -98,7 +96,7 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
if err != nil {
// Skip if we can't get pod UID, because this doesn't work
// for containerd 1.1.
- return nil
+ return false, nil
}
}
var updated bool
@@ -114,7 +112,7 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
// This is a sandbox.
path, err := volumePath(volume, uid)
if err != nil {
- return fmt.Errorf("get volume path for %q: %w", volume, err)
+ return false, fmt.Errorf("get volume path for %q: %w", volume, err)
}
s.Annotations[volumeSourceKey(volume)] = path
updated = true
@@ -138,15 +136,7 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
}
}
}
- if !updated {
- return nil
- }
- // Update bundle.
- b, err := json.Marshal(s)
- if err != nil {
- return err
- }
- return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666)
+ return updated, nil
}
func changeMountType(m *specs.Mount, newType string) {
diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go
index b25c53c73..5db43cdf1 100644
--- a/pkg/shim/utils/volumes_test.go
+++ b/pkg/shim/utils/volumes_test.go
@@ -15,11 +15,9 @@
package utils
import (
- "encoding/json"
"fmt"
"io/ioutil"
"os"
- "path/filepath"
"reflect"
"testing"
@@ -58,7 +56,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
spec: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -67,7 +65,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
expected: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -81,7 +79,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
spec: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLegacyLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -90,7 +88,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
expected: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLegacyLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -117,7 +115,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -139,7 +137,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -159,7 +157,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "container",
volumeKeyPrefix + testVolumeName + ".type": "bind",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -175,7 +173,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "container",
volumeKeyPrefix + testVolumeName + ".type": "bind",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -187,7 +185,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
name: "should not return error without pod log directory",
spec: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -195,7 +193,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
expected: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -207,7 +205,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
spec: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
volumeKeyPrefix + "notexist.share": "pod",
volumeKeyPrefix + "notexist.type": "tmpfs",
volumeKeyPrefix + "notexist.options": "ro",
@@ -220,13 +218,13 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
spec: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
},
},
expected: &specs.Spec{
Annotations: map[string]string{
sandboxLogDirAnnotation: testLogDirPath,
- containerTypeAnnotation: containerTypeSandbox,
+ ContainerTypeAnnotation: containerTypeSandbox,
},
},
},
@@ -248,7 +246,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
},
},
expected: &specs.Spec{
@@ -267,7 +265,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
},
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
},
},
},
@@ -275,7 +273,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
name: "bind options removed",
spec: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -292,7 +290,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
expected: &specs.Spec{
Annotations: map[string]string{
- containerTypeAnnotation: containerTypeContainer,
+ ContainerTypeAnnotation: ContainerTypeContainer,
volumeKeyPrefix + testVolumeName + ".share": "pod",
volumeKeyPrefix + testVolumeName + ".type": "tmpfs",
volumeKeyPrefix + testVolumeName + ".options": "ro",
@@ -311,11 +309,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
- bundle, err := ioutil.TempDir(dir, "test-bundle")
- if err != nil {
- t.Fatalf("Create test bundle: %v", err)
- }
- err = UpdateVolumeAnnotations(bundle, test.spec)
+ updated, err := UpdateVolumeAnnotations(test.spec)
if test.expectErr {
if err == nil {
t.Fatal("Expected error, but got nil")
@@ -328,18 +322,8 @@ func TestUpdateVolumeAnnotations(t *testing.T) {
if !reflect.DeepEqual(test.expected, test.spec) {
t.Fatalf("Expected %+v, got %+v", test.expected, test.spec)
}
- if test.expectUpdate {
- b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json"))
- if err != nil {
- t.Fatalf("Read spec from bundle: %v", err)
- }
- var spec specs.Spec
- if err := json.Unmarshal(b, &spec); err != nil {
- t.Fatalf("Unmarshal spec: %v", err)
- }
- if !reflect.DeepEqual(test.expected, &spec) {
- t.Fatalf("Expected %+v, got %+v", test.expected, &spec)
- }
+ if test.expectUpdate != updated {
+ t.Errorf("Expected %v, got %v", test.expectUpdate, updated)
}
})
}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 8b3a11c64..73791b456 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -1,5 +1,4 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template")
package(
default_visibility = ["//:sandbox"],
@@ -8,45 +7,6 @@ package(
exports_files(["LICENSE"])
-go_template(
- name = "generic_atomicptr",
- srcs = ["generic_atomicptr_unsafe.go"],
- types = [
- "Value",
- ],
-)
-
-go_template(
- name = "generic_atomicptrmap",
- srcs = ["generic_atomicptrmap_unsafe.go"],
- opt_consts = [
- "ShardOrder",
- ],
- opt_types = [
- "Hasher",
- ],
- types = [
- "Key",
- "Value",
- ],
- deps = [
- ":sync",
- "//pkg/gohacks",
- ],
-)
-
-go_template(
- name = "generic_seqatomic",
- srcs = ["generic_seqatomic_unsafe.go"],
- types = [
- "Value",
- ],
- deps = [
- ":sync",
- "//pkg/gohacks",
- ],
-)
-
go_library(
name = "sync",
srcs = [
diff --git a/pkg/sync/README.md b/pkg/sync/README.md
index 2183c4e20..be1a01f08 100644
--- a/pkg/sync/README.md
+++ b/pkg/sync/README.md
@@ -1,4 +1,4 @@
-# Syncutil
+# sync
This package provides additional synchronization primitives not provided by the
Go stdlib 'sync' package. It is partially derived from the upstream 'sync'
diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptr/BUILD
index e97553254..a6a7f01ac 100644
--- a/pkg/sync/atomicptrtest/BUILD
+++ b/pkg/sync/atomicptr/BUILD
@@ -1,14 +1,23 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
+go_template(
+ name = "generic_atomicptr",
+ srcs = ["generic_atomicptr_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ visibility = ["//:sandbox"],
+)
+
go_template_instance(
name = "atomicptr_int",
out = "atomicptr_int_unsafe.go",
package = "atomicptr",
suffix = "Int",
- template = "//pkg/sync:generic_atomicptr",
+ template = ":generic_atomicptr",
types = {
"Value": "int",
},
diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go
index 8fdc5112e..8fdc5112e 100644
--- a/pkg/sync/atomicptrtest/atomicptr_test.go
+++ b/pkg/sync/atomicptr/atomicptr_test.go
diff --git a/pkg/sync/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
index 82b6df18c..82b6df18c 100644
--- a/pkg/sync/generic_atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmap/BUILD
index 3f71ae97d..b0e218c79 100644
--- a/pkg/sync/atomicptrmaptest/BUILD
+++ b/pkg/sync/atomicptrmap/BUILD
@@ -1,17 +1,36 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
+go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
go_template_instance(
name = "test_atomicptrmap",
out = "test_atomicptrmap_unsafe.go",
package = "atomicptrmap",
prefix = "test",
- template = "//pkg/sync:generic_atomicptrmap",
+ template = ":generic_atomicptrmap",
types = {
"Key": "int64",
"Value": "testValue",
@@ -27,7 +46,7 @@ go_template_instance(
package = "atomicptrmap",
prefix = "test",
suffix = "Sharded",
- template = "//pkg/sync:generic_atomicptrmap",
+ template = ":generic_atomicptrmap",
types = {
"Key": "int64",
"Value": "testValue",
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go
index 867821ce9..867821ce9 100644
--- a/pkg/sync/atomicptrmaptest/atomicptrmap.go
+++ b/pkg/sync/atomicptrmap/atomicptrmap.go
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go
index 75a9997ef..75a9997ef 100644
--- a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
+++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go
diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
index 3e98cb309..3e98cb309 100644
--- a/pkg/sync/generic_atomicptrmap_unsafe.go
+++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomic/BUILD
index 5f9164117..60f79ab54 100644
--- a/pkg/sync/seqatomictest/BUILD
+++ b/pkg/sync/seqatomic/BUILD
@@ -1,14 +1,27 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
+go_template(
+ name = "generic_seqatomic",
+ srcs = ["generic_seqatomic_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
go_template_instance(
name = "seqatomic_int",
out = "seqatomic_int_unsafe.go",
package = "seqatomic",
suffix = "Int",
- template = "//pkg/sync:generic_seqatomic",
+ template = ":generic_seqatomic",
types = {
"Value": "int",
},
diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go
index 9578c9c52..9578c9c52 100644
--- a/pkg/sync/generic_seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go
diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go
index 2c4568b07..2c4568b07 100644
--- a/pkg/sync/seqatomictest/seqatomic_test.go
+++ b/pkg/sync/seqatomic/seqatomic_test.go
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index 7d2f5adf6..76bee5a64 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,12 +8,3 @@ go_library(
visibility = ["//visibility:public"],
deps = ["@org_golang_x_sys//unix:go_default_library"],
)
-
-go_test(
- name = "syserror_test",
- srcs = ["syserror_test.go"],
- deps = [
- ":syserror",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ea46c30da..ed4d7e958 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -39,12 +39,30 @@ go_library(
deps_test(
name = "netstack_deps_test",
allowed = [
+ # gVisor deps.
+ "//pkg/atomicbitops",
+ "//pkg/buffer",
+ "//pkg/context",
+ "//pkg/gohacks",
+ "//pkg/goid",
+ "//pkg/ilist",
+ "//pkg/iovec",
+ "//pkg/linewriter",
+ "//pkg/log",
+ "//pkg/rand",
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/state/wire",
+ "//pkg/sync",
+ "//pkg/waiter",
+
+ # Other deps.
"@com_github_google_btree//:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
"@org_golang_x_time//rate:go_default_library",
],
allowed_prefixes = [
- "//",
+ "//pkg/tcpip",
"@org_golang_x_sys//internal/unsafeheader",
],
targets = [
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 18e6cc3cd..e0dfe5813 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
ipv4 := header.IPv4(b)
if !ipv4.IsValid(len(b)) {
- t.Error("Not a valid IPv4 packet")
+ t.Fatalf("Not a valid IPv4 packet: %x", ipv4)
}
if !ipv4.IsChecksumValid() {
@@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
ipv6 := header.IPv6(b)
if !ipv6.IsValid(len(b)) {
- t.Error("Not a valid IPv6 packet")
+ t.Fatalf("Not a valid IPv6 packet: %x", ipv6)
}
for _, f := range checkers {
@@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
if !ok {
return
}
- opts := []byte(tcp.Options())
+ opts := tcp.Options()
limit := len(opts)
foundTS := false
tsVal := uint32(0)
@@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
-// not contain any SACK blocks in the TCP options.
-func TCPNoSACKBlockChecker() TransportChecker {
- return TCPSACKBlockChecker(nil)
-}
-
// TCPSACKBlockChecker creates a checker that verifies that the segment does
// contain the specified SACK blocks in the TCP options.
func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
@@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
var gotSACKBlocks []header.SACKBlock
- opts := []byte(tcp.Options())
+ opts := tcp.Options()
limit := len(opts)
for i := 0; i < limit; {
switch opts[i] {
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
index fb819d7a8..9f8f51647 100644
--- a/pkg/tcpip/faketime/faketime.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -29,14 +29,14 @@ type NullClock struct{}
var _ tcpip.Clock = (*NullClock)(nil)
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (*NullClock) NowNanoseconds() int64 {
- return 0
+// Now implements tcpip.Clock.Now.
+func (*NullClock) Now() time.Time {
+ return time.Time{}
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (*NullClock) NowMonotonic() int64 {
- return 0
+func (*NullClock) NowMonotonic() tcpip.MonotonicTime {
+ return tcpip.MonotonicTime{}
}
// AfterFunc implements tcpip.Clock.AfterFunc.
@@ -118,16 +118,17 @@ func NewManualClock() *ManualClock {
var _ tcpip.Clock = (*ManualClock)(nil)
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (mc *ManualClock) NowNanoseconds() int64 {
+// Now implements tcpip.Clock.Now.
+func (mc *ManualClock) Now() time.Time {
mc.mu.RLock()
defer mc.mu.RUnlock()
- return mc.mu.now.UnixNano()
+ return mc.mu.now
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (mc *ManualClock) NowMonotonic() int64 {
- return mc.NowNanoseconds()
+func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime {
+ var mt tcpip.MonotonicTime
+ return mt.Add(mc.Now().Sub(time.Unix(0, 0)))
}
// AfterFunc implements tcpip.Clock.AfterFunc.
@@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) {
}
}
+// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func (mc *ManualClock) RunImmediatelyScheduledJobs() {
+ mc.Advance(0)
+}
+
// Advance executes all work that have been scheduled to execute within d from
// the current time. Blocks until all work has completed execution.
func (mc *ManualClock) Advance(d time.Duration) {
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
index c2704df2c..fd2bb470a 100644
--- a/pkg/tcpip/faketime/faketime_test.go
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) {
clock := faketime.NewManualClock()
start := clock.NowMonotonic()
clock.Advance(timeout)
- if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ if got, want := clock.NowMonotonic().Sub(start), timeout; got != want {
t.Errorf("got = %d, want = %d", got, want)
}
}
@@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) {
if got, want := counter2, test.wantCounter2; got != want {
t.Errorf("got counter2 = %d, want = %d", got, want)
}
- if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want {
t.Errorf("got elapsed = %d, want = %d", got, want)
}
})
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 2be21ec75..e9abbb709 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -17,6 +17,7 @@ package header
import (
"encoding/binary"
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -181,7 +182,7 @@ const (
// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
// by RFC 3927 section 1.
var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+ subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", "\xff\xff\x00\x00")
if err != nil {
panic(err)
}
@@ -191,7 +192,7 @@ var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
// defined by RFC 5771 section 4.
var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+ subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", "\xff\xff\xff\x00")
if err != nil {
panic(err)
}
@@ -572,7 +573,7 @@ func (o *IPv4OptionGeneric) Type() IPv4OptionType {
func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) }
// Contents implements IPv4Option.
-func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) }
+func (o *IPv4OptionGeneric) Contents() []byte { return *o }
// IPv4OptionIterator is an iterator pointing to a specific IP option
// at any point of time. It also holds information as to a new options buffer
@@ -610,7 +611,7 @@ func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options {
// RemainingBuffer returns the remaining (unused) part of the new option buffer,
// into which a new option may be written.
func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options {
- return IPv4Options(i.newOptions[i.writePoint:])
+ return i.newOptions[i.writePoint:]
}
// ConsumeBuffer marks a portion of the new buffer as used.
@@ -813,9 +814,12 @@ const (
// ipv4TimestampTime provides the current time as specified in RFC 791.
func ipv4TimestampTime(clock tcpip.Clock) uint32 {
- const millisecondsPerDay = 24 * 3600 * 1000
- const nanoPerMilli = 1000000
- return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+ // Per RFC 791 page 21:
+ // The Timestamp is a right-justified, 32-bit timestamp in
+ // milliseconds since midnight UT.
+ now := clock.Now().UTC()
+ midnight := now.Truncate(24 * time.Hour)
+ return uint32(now.Sub(midnight).Milliseconds())
}
// IP Timestamp option fields.
@@ -843,7 +847,7 @@ func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestam
func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
// Contents implements IPv4Option.
-func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts }
// Pointer returns the pointer field in the IP Timestamp option.
func (ts *IPv4OptionTimestamp) Pointer() uint8 {
@@ -947,7 +951,7 @@ func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecord
func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
-func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr }
// Router Alert option specific related constants.
//
@@ -992,7 +996,7 @@ func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAle
func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) }
// Contents implements IPv4Option.
-func (ra *IPv4OptionRouterAlert) Contents() []byte { return []byte(*ra) }
+func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra }
// Value returns the value of the IPv4OptionRouterAlert.
func (ra *IPv4OptionRouterAlert) Value() uint16 {
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 3d1bccd15..d6cad3a94 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -77,12 +77,12 @@ const (
// ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag
// field in the flags byte within an NDPPrefixInformation.
- ndpPrefixInformationOnLinkFlagMask = (1 << 7)
+ ndpPrefixInformationOnLinkFlagMask = 1 << 7
// ndpPrefixInformationAutoAddrConfFlagMask is the mask of the
// Autonomous Address-Configuration flag field in the flags byte within
// an NDPPrefixInformation.
- ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6)
+ ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6
// ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1
// field in the flags byte within an NDPPrefixInformation.
@@ -451,7 +451,7 @@ func (o NDPNonceOption) String() string {
// Nonce returns the nonce value this option holds.
func (o NDPNonceOption) Nonce() []byte {
- return []byte(o)
+ return o
}
// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index bf7610863..7e2f0c797 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -19,12 +19,72 @@ import (
"time"
)
+// NDPRoutePreference is the preference value for default routers or
+// more-specific routes.
+//
+// As per RFC 4191 section 2.1,
+//
+// Default router preferences and preferences for more-specific routes
+// are encoded the same way.
+//
+// Preference values are encoded as a two-bit signed integer, as
+// follows:
+//
+// 01 High
+// 00 Medium (default)
+// 11 Low
+// 10 Reserved - MUST NOT be sent
+//
+// Note that implementations can treat the value as a two-bit signed
+// integer.
+//
+// Having just three values reinforces that they are not metrics and
+// more values do not appear to be necessary for reasonable scenarios.
+type NDPRoutePreference uint8
+
+const (
+ // HighRoutePreference indicates a high preference, as per
+ // RFC 4191 section 2.1.
+ HighRoutePreference NDPRoutePreference = 0b01
+
+ // MediumRoutePreference indicates a medium preference, as per
+ // RFC 4191 section 2.1.
+ //
+ // This is the default preference value.
+ MediumRoutePreference NDPRoutePreference = 0b00
+
+ // LowRoutePreference indicates a low preference, as per
+ // RFC 4191 section 2.1.
+ LowRoutePreference NDPRoutePreference = 0b11
+
+ // ReservedRoutePreference is a reserved preference value, as per
+ // RFC 4191 section 2.1.
+ //
+ // It MUST NOT be sent.
+ ReservedRoutePreference NDPRoutePreference = 0b10
+)
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
-// See RFC 4861 section 4.2 for more details.
+// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
+// As per RFC 4191 section 2.2,
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Reachable Time |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Retrans Timer |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options ...
+// +-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
@@ -47,6 +107,14 @@ const (
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
+ // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceShift = 3
+
+ // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
+
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
@@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
}
+// DefaultRouterPreference returns the Default Router Preference field.
+func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
+ return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
+}
+
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1b5093e58..8fd1f7d13 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
func TestNDPRouterAdvert(t *testing.T) {
- b := []byte{
- 64, 128, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
+ tests := []struct {
+ hopLimit uint8
+ managedFlag, otherConfFlag bool
+ prf NDPRoutePreference
+ routerLifetimeS uint16
+ reachableTimeMS, retransTimerMS uint32
+ }{
+ {
+ hopLimit: 1,
+ managedFlag: false,
+ otherConfFlag: true,
+ prf: HighRoutePreference,
+ routerLifetimeS: 2,
+ reachableTimeMS: 3,
+ retransTimerMS: 4,
+ },
+ {
+ hopLimit: 64,
+ managedFlag: true,
+ otherConfFlag: false,
+ prf: LowRoutePreference,
+ routerLifetimeS: 258,
+ reachableTimeMS: 78492,
+ retransTimerMS: 13213,
+ },
}
- ra := NDPRouterAdvert(b)
+ for i, test := range tests {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ flags := uint8(0)
+ if test.managedFlag {
+ flags |= 1 << 7
+ }
+ if test.otherConfFlag {
+ flags |= 1 << 6
+ }
+ flags |= uint8(test.prf) << 3
- if got := ra.CurrHopLimit(); got != 64 {
- t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
- }
+ b := []byte{
+ test.hopLimit, flags, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+ binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS)
+ binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS)
+ binary.BigEndian.PutUint32(b[8:], test.retransTimerMS)
- if got := ra.ManagedAddrConfFlag(); !got {
- t.Errorf("got ManagedAddrConfFlag = false, want = true")
- }
+ ra := NDPRouterAdvert(b)
- if got := ra.OtherConfFlag(); got {
- t.Errorf("got OtherConfFlag = true, want = false")
- }
+ if got := ra.CurrHopLimit(); got != test.hopLimit {
+ t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit)
+ }
- if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
- }
+ if got := ra.ManagedAddrConfFlag(); got != test.managedFlag {
+ t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag)
+ }
- if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
- }
+ if got := ra.OtherConfFlag(); got != test.otherConfFlag {
+ t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag)
+ }
+
+ if got := ra.DefaultRouterPreference(); got != test.prf {
+ t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf)
+ }
- if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want {
+ t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want {
+ t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want {
+ t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 0df517000..8dabe3354 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -48,6 +48,16 @@ const (
// TCPFlags is the dedicated type for TCP flags.
type TCPFlags uint8
+// Intersects returns true iff there are flags common to both f and o.
+func (f TCPFlags) Intersects(o TCPFlags) bool {
+ return f&o != 0
+}
+
+// Contains returns true iff all the flags in o are contained within f.
+func (f TCPFlags) Contains(o TCPFlags) bool {
+ return f&o == o
+}
+
// String implements Stringer.String.
func (f TCPFlags) String() string {
flagsStr := []byte("FSRPAU")
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index ef9126deb..f26c857eb 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index bddb1d0a2..735c28da1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,7 +41,6 @@ package fdbased
import (
"fmt"
- "math"
"sync/atomic"
"golang.org/x/sys/unix"
@@ -196,8 +195,12 @@ type Options struct {
// option for an FD with a fanoutID already in use by another FD for a different
// NIC will return an EINVAL.
//
+// Since fanoutID must be unique within the network namespace, we start with
+// the PID to avoid collisions. The only way to be sure of avoiding collisions
+// is to run in a new network namespace.
+//
// Must be accessed using atomic operations.
-var fanoutID int32 = 0
+var fanoutID int32 = int32(unix.Getpid())
// New creates a new fd-based endpoint.
//
@@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // See: PACKET_FANOUT_MAX in net/packet/internal.h
- const packetFanoutMax = 1 << 16
- if fID > packetFanoutMax {
- return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16)
- }
// Enable PACKET_FANOUT mode if the underlying socket is of type
// AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
// prevent gvisor from receiving fragmented packets and the host does the
@@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
//
// See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881
const fanoutType = unix.PACKET_FANOUT_HASH
- fanoutArg := int(fID) | fanoutType<<16
+ fanoutArg := (int(fID) & 0xffff) | fanoutType<<16
if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
}
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index a72eb1aad..6fa1aee18 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -28,13 +28,13 @@ go_test(
":arp",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 0efa3a926..6515c31e5 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -278,7 +278,7 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
protocol: p,
nic: nic,
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 94209b026..5fcbfeaa2 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,15 +15,14 @@
package arp_test
import (
- "context"
"fmt"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -31,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
const (
@@ -39,15 +37,6 @@ const (
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
-
- defaultChannelSize = 1
- defaultMTU = 65536
-
- // eventChanSize defines the size of event channels used by the neighbor
- // cache's event dispatcher. The size chosen here needs to be sufficient to
- // queue all the events received during tests before consumption.
- // If eventChanSize is too small, the tests may deadlock.
- eventChanSize = 32
)
var (
@@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
d.C <- e
}
-func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
- select {
- case got := <-d.C:
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- case <-ctx.Done():
- return fmt.Errorf("%s for %s", ctx.Err(), want)
- }
- return nil
-}
-
-func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- return d.waitForEvent(ctx, want)
-}
-
func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
select {
case event := <-d.C:
@@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
type testContext struct {
s *stack.Stack
linkEP *channel.Endpoint
- nudDisp *arpDispatcher
+ nudDisp arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
- c := stack.DefaultNUDConfigurations()
- // Transition from Reachable to Stale almost immediately to test if receiving
- // probes refreshes positive reachability.
- c.BaseReachableTime = time.Microsecond
-
- d := arpDispatcher{
- // Create an event channel large enough so the neighbor cache doesn't block
- // while dispatching events. Blocking could interfere with the timing of
- // NUD transitions.
- C: make(chan eventInfo, eventChanSize),
+func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext {
+ t.Helper()
+
+ tc := testContext{
+ nudDisp: arpDispatcher{
+ C: make(chan eventInfo, eventDepth),
+ },
}
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- NUDConfigs: c,
- NUDDisp: &d,
+ tc.s = stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ NUDDisp: &tc.nudDisp,
+ Clock: &faketime.NullClock{},
})
- ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- wep := stack.LinkEndpoint(ep)
+ tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr)
+ tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ wep := stack.LinkEndpoint(tc.linkEP)
if testing.Verbose() {
- wep = sniffer.New(ep)
+ wep = sniffer.New(wep)
}
- if err := s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ if err := tc.s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %s", err)
}
- s.SetRouteTable([]tcpip.Route{{
+ tc.s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: nicID,
}})
- return &testContext{
- s: s,
- linkEP: ep,
- nudDisp: &d,
- }
+ return tc
}
func (c *testContext) cleanup() {
@@ -209,7 +170,7 @@ func (c *testContext) cleanup() {
}
func TestMalformedPacket(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
v := make(buffer.View, header.ARPSize)
@@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) {
}
func TestDisabledEndpoint(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
@@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) {
}
func TestDirectReply(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 1, 1)
defer c.cleanup()
tests := []struct {
@@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) {
}
// Verify the sender was saved in the neighbor cache.
- wantEvent := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: stack.NeighborEntry{
- Addr: test.senderAddr,
- LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- State: stack.Stale,
- },
- }
- if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
- t.Fatal(err)
+ if got, ok := c.nudDisp.nextEvent(); ok {
+ want := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: test.senderLinkAddr,
+ State: stack.Stale,
+ },
+ }
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ t.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ } else {
+ t.Fatal("event didn't arrive")
}
neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
@@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) {
}
func TestDADARPRequestPacket(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
},
}), ipv4.NewProtocol},
+ Clock: clock,
})
- e := channel.New(1, defaultMTU, stackLinkAddr)
+ e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting)
}
- pkt, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ pkt, ok := e.Read()
if !ok {
t.Fatal("expected to send an ARP request")
}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
index 5168f5361..1ba4d0d36 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
@@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() {
// The list is empty.
break
}
- elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ elapsed := now.Sub(r.createdAt)
if f.timeout > elapsed {
// If the oldest reassembler has not expired, schedule the release
// job so that this function is called back when it has expired.
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
index 7daf64b4a..dadfc28cc 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
@@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) {
highLimit := 3 * lowLimit // Allow at most 3 such packets.
f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
+ if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
+ if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
+ if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil {
+ t.Fatal(err)
+ }
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
memSize := pkt(1, "0").MemSize()
f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
if got, want := f.memSize, memSize; got != want {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 56b76a284..5b7e4b361 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -35,21 +35,21 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id FragmentID
- memSize int
- proto uint8
- mu sync.Mutex
- holes []hole
- filled int
- done bool
- creationTime int64
- pkt *stack.PacketBuffer
+ id FragmentID
+ memSize int
+ proto uint8
+ mu sync.Mutex
+ holes []hole
+ filled int
+ done bool
+ createdAt tcpip.MonotonicTime
+ pkt *stack.PacketBuffer
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
- id: id,
- creationTime: clock.NowMonotonic(),
+ id: id,
+ createdAt: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index eed49f5d2..5123b7d6a 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
}
+ configs.Validate()
+
*d = DAD{
opts: opts,
configs: configs,
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index a22b712c6..24687cf06 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled)
}
// Wait for any initially fired timers to complete.
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check(nil); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check(nil); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr2}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) {
if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -347,7 +347,7 @@ func TestNonce(t *testing.T) {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
for i, want := range test.expectedResults {
if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index d22974b12..671dfbf32 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// If a timer for any address is already running, it is reset to the new
// random value only if the requested Maximum Response Delay is less than
// the remaining value of the running timer.
- now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds())
+ now := g.opts.Clock.Now()
if info.state == delayingMember {
if info.delayedReportJobFiresAt.IsZero() {
panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress))
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 0c2b62127..40ab21cb6 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct {
// were too big for the outgoing MTU.
PacketTooBig tcpip.MultiCounterStat
+ // HostUnreachable is the number of IP packets received which could not be
+ // successfully forwarded due to an unresolvable next hop.
+ HostUnreachable tcpip.MultiCounterStat
+
// ExtensionHeaderProblem is the number of IP packets which were dropped
// because of a problem encountered when processing an IPv6 extension
// header.
@@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig)
m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
+ m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable)
}
// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index cec3e62c4..a180e5c75 100644
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip/network/internal/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
+ "//pkg/tcpip/tests/integration:__pkg__",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index d1a82b584..2aa38eb98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
@@ -222,7 +221,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
e.stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -481,6 +479,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool {
return true
}
+// icmpReasonHostUnreachable is an error in which the host specified in the
+// internet destination field of the datagram is unreachable.
+type icmpReasonHostUnreachable struct{}
+
+func (*icmpReasonHostUnreachable) isICMPReason() {}
+func (*icmpReasonHostUnreachable) isForwarding() bool {
+ // If we hit a Host Unreachable error, then we know we are operating as a
+ // router. As per RFC 792 page 5, Destination Unreachable Message,
+ //
+ // In addition, in some networks, the gateway may be able to determine
+ // if the internet destination host is unreachable. Gateways in these
+ // networks may send destination unreachable messages to the source host
+ // when the destination host is unreachable.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -537,7 +551,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
defer route.Release()
p.mu.Lock()
- netEP, ok := p.mu.eps[pkt.NICID]
+ // We retrieve an endpoint using the newly constructed route's NICID rather
+ // than the packet's NICID. The packet's NICID corresponds to the NIC on
+ // which it arrived, which isn't necessarily the same as the NIC on which it
+ // will be transmitted. On the other hand, the route's NIC *is* guaranteed
+ // to be the NIC on which the packet will be transmitted.
+ netEP, ok := p.mu.eps[route.NICID()]
p.mu.Unlock()
if !ok {
return &tcpip.ErrNotConnected{}
@@ -653,6 +672,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4NetUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonHostUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4HostUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonFragmentationNeeded:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 049811cbb..f08b008ac 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -63,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -76,12 +82,18 @@ type endpoint struct {
protocol *protocol
stats sharedStats
- // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
// disabled.
//
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -92,6 +104,16 @@ type endpoint struct {
// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // If we are operating as a router, return an ICMP error to the original
+ // packet's sender.
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP
+ // errors to local endpoints.
+ e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt)
+ e.stats.ip.Forwarding.Errors.Increment()
+ e.stats.ip.Forwarding.HostUnreachable.Increment()
+ return
+ }
// handleControl expects the entire offending packet to be in the packet
// buffer's data field.
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -151,14 +173,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -292,8 +332,8 @@ func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
+// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the
+// network layer max header length.
func (e *endpoint) MTU() uint32 {
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
if err != nil {
@@ -308,7 +348,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
-// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+// NetworkProtocolNumber implements stack.NetworkEndpoint.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
@@ -323,7 +363,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
if hdrLen > header.IPv4MaximumHeaderSize {
return &tcpip.ErrMessageTooLong{}
}
- ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
+ ipH := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := pkt.Size()
if length > math.MaxUint16 {
return &tcpip.ErrMessageTooLong{}
@@ -332,7 +372,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
- ip.Encode(&header.IPv4Fields{
+ ipH.Encode(&header.IPv4Fields{
TotalLength: uint16(length),
ID: uint16(id),
TTL: params.TTL,
@@ -342,7 +382,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
DstAddr: dstAddr,
Options: options,
})
- ip.SetChecksum(^ip.CalculateChecksum())
+ ipH.SetChecksum(^ipH.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
return nil
}
@@ -351,7 +391,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
+func (e *endpoint) handleFragments(_ *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -389,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -460,7 +500,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
return nil
}
-// WritePackets implements stack.NetworkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.
func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
@@ -563,34 +603,34 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return &tcpip.ErrMalformedHeader{}
}
- ip := header.IPv4(h)
+ ipH := header.IPv4(h)
// Always set the total length.
pktSize := pkt.Data().Size()
- ip.SetTotalLength(uint16(pktSize))
+ ipH.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == header.IPv4Any {
- ip.SetSourceAddress(r.LocalAddress())
+ if ipH.SourceAddress() == header.IPv4Any {
+ ipH.SetSourceAddress(r.LocalAddress())
}
// Set the destination. If the packet already included a destination, it will
// be part of the route anyways.
- ip.SetDestinationAddress(r.RemoteAddress())
+ ipH.SetDestinationAddress(r.RemoteAddress())
// Set the packet ID when zero.
- if ip.ID() == 0 {
+ if ipH.ID() == 0 {
// RFC 6864 section 4.3 mandates uniqueness of ID values for
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
- if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ if ipH.Flags()&header.IPv4FlagDontFragment == 0 || ipH.Flags()&header.IPv4FlagMoreFragments != 0 || ipH.FragmentOffset() > 0 {
+ ipH.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
// Always set the checksum.
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
+ ipH.SetChecksum(0)
+ ipH.SetChecksum(^ipH.CalculateChecksum())
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
@@ -680,7 +720,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -796,7 +837,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -815,10 +856,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -852,7 +893,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -868,7 +909,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
case *ip.ErrNoRoute:
stats.ip.Forwarding.Unrouteable.Increment()
case *ip.ErrParameterProblem:
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
case *ip.ErrMessageTooLong:
stats.ip.Forwarding.PacketTooBig.Increment()
@@ -881,8 +921,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -905,7 +944,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
e.stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -955,7 +993,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
- h.SetTotalLength(uint16(pkt.Data().Size() + len((h))))
+ h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
}
stats.ip.PacketsDelivered.Increment()
@@ -978,7 +1016,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -1144,7 +1181,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1165,12 +1201,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1194,13 +1224,13 @@ func (p *protocol) DefaultPrefixLen() int {
return header.IPv4AddressSize * 8
}
-// ParseAddresses implements NetworkProtocol.ParseAddresses.
+// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
return h.SourceAddress(), h.DestinationAddress()
}
-// SetOption implements NetworkProtocol.SetOption.
+// SetOption implements stack.NetworkProtocol.
func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -1211,7 +1241,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E
}
}
-// Option implements NetworkProtocol.Option.
+// Option implements stack.NetworkProtocol.
func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -1232,10 +1262,10 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
-// Close implements stack.TransportProtocol.Close.
+// Close implements stack.NetworkProtocol.
func (*protocol) Close() {}
-// Wait implements stack.TransportProtocol.Wait.
+// Wait implements stack.NetworkProtocol.
func (*protocol) Wait() {}
// parseAndValidate parses the packet (including its transport layer header) and
@@ -1273,7 +1303,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
return h, true
}
-// Parse implements stack.NetworkProtocol.Parse.
+// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
return 0, false, false
@@ -1283,35 +1313,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1332,7 +1333,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error
networkMTU = MaxTotalSize
}
- return networkMTU - uint32(networkHeaderSize), nil
+ return networkMTU - networkHeaderSize, nil
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool {
@@ -1741,9 +1742,8 @@ type optionTracker struct {
//
// If there were no errors during parsing, the new set of options is returned as
// a new buffer.
-func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) {
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, opts header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) {
stats := e.stats.ip
- opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
// Except NOP, each option must only appear at most once (RFC 791 section 3.1,
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 5f45b9ee6..4a4448cf9 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,7 +16,6 @@ package ipv4_test
import (
"bytes"
- "context"
"encoding/hex"
"fmt"
"io/ioutil"
@@ -36,10 +35,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -112,33 +111,29 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-type forwardedPacket struct {
- fragments []fragmentInfo
-}
-
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
randomTimeOffset = 0x10203040
)
- ipv4Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
PrefixLen: 8,
}
- ipv4Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
PrefixLen: 8,
}
- linkAddr2 := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
- remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
- unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4())
- multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4())
- linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4())
+ outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2")
+ remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2")
+ unreachableIPv4Addr := testutil.MustParse4("12.0.0.2")
+ multicastIPv4Addr := testutil.MustParse4("225.0.0.0")
+ linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0")
tests := []struct {
name string
@@ -345,6 +340,7 @@ func TestForwarding(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
@@ -356,36 +352,36 @@ func TestForwarding(t *testing.T) {
clock.Advance(time.Millisecond * randomTimeOffset)
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, test.mtu, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount {
expectedEmittedPacketCount = len(test.expectedFragmentsForwarded)
}
- e2 := channel.New(expectedEmittedPacketCount, test.mtu, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr)
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv4Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv4Addr.Subnet(),
+ NIC: incomingNICID,
},
{
- Destination: ipv4Addr2.Subnet(),
- NIC: nicID2,
+ Destination: outgoingIPv4Addr.Subnet(),
+ NIC: outgoingNICID,
},
})
@@ -401,13 +397,13 @@ func TestForwarding(t *testing.T) {
totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
- icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
+ icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
@@ -431,9 +427,10 @@ func TestForwarding(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
})
requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
- reply, ok := e1.Read()
if test.expectErrorICMP {
if !ok {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
@@ -451,15 +448,15 @@ func TestForwarding(t *testing.T) {
return len(hdr.View())
}
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv4Addr1.Address),
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(incomingIPv4Addr.Address),
checker.DstAddr(test.sourceAddr),
checker.TTL(ipv4.DefaultTTL),
checker.ICMPv4(
checker.ICMPv4Checksum(),
checker.ICMPv4Type(test.icmpType),
checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
+ checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]),
),
)
} else if ok {
@@ -468,9 +465,9 @@ func TestForwarding(t *testing.T) {
if test.expectPacketForwarded {
if len(test.expectedFragmentsForwarded) != 0 {
- fragmentedPackets := []*stack.PacketBuffer{}
+ var fragmentedPackets []*stack.PacketBuffer
for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
- reply, ok = e2.Read()
+ reply, ok = outgoingEndpoint.Read()
if !ok {
t.Fatal("expected ICMP Echo fragment through outgoing NIC")
}
@@ -485,16 +482,16 @@ func TestForwarding(t *testing.T) {
// maximum IP header size and the maximum size allocated for link layer
// headers. In this case, no size is allocated for link layer headers.
expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
- if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
+ if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
t.Error(err)
}
} else {
- reply, ok = e2.Read()
+ reply, ok = outgoingEndpoint.Read()
if !ok {
t.Fatal("expected ICMP Echo packet through outgoing NIC")
}
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(test.sourceAddr),
checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
@@ -508,11 +505,10 @@ func TestForwarding(t *testing.T) {
)
}
} else {
- if reply, ok = e2.Read(); ok {
+ if reply, ok = outgoingEndpoint.Read(); ok {
t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
}
}
-
boolToInt := func(val bool) uint64 {
if val {
return 1
@@ -1211,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) {
}
totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
// Specify ident/seq to make sure we get the same in the response.
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
@@ -1314,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) {
checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
checker.ICMPv4Pointer(test.paramProblemPointer),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload(hdr.View()),
),
)
return
@@ -1333,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) {
checker.ICMPv4(
checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload(hdr.View()),
),
)
return
@@ -1545,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -1601,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
+ tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
for _, test := range writePacketsTests {
t.Run(test.description, func(t *testing.T) {
@@ -1611,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
@@ -1725,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -2300,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4TimeExceeded),
checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
checker.ICMPv4Checksum(),
- checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ checker.ICMPv4Payload(firstFragmentSent),
),
)
})
@@ -2704,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
RawFactory: raw.EndpointFactory{},
})
- e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
+ e := channel.New(0, 1280, "\xf0\x00")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -2947,7 +2943,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
+ ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -3034,7 +3030,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string)
return false, false
}
-func TestPacketQueing(t *testing.T) {
+func TestPacketQueuing(t *testing.T) {
const nicID = 1
var (
@@ -3073,7 +3069,7 @@ func TestPacketQueing(t *testing.T) {
Length: header.UDPMinimumSize,
})
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
@@ -3089,7 +3085,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3132,7 +3128,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3156,9 +3152,11 @@ func TestPacketQueing(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := channel.New(1, defaultMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -3181,7 +3179,8 @@ func TestPacketQueing(t *testing.T) {
// Wait for an ARP request since link address resolution should be
// performed.
{
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3222,6 +3221,7 @@ func TestPacketQueing(t *testing.T) {
}
// Expect the response now that the link address has resolved.
+ clock.RunImmediatelyScheduledJobs()
test.checkResp(t, e)
// Since link resolution was already performed, it shouldn't be performed
@@ -3243,8 +3243,8 @@ func TestCloseLocking(t *testing.T) {
)
var (
- src = tcptestutil.MustParse4("16.0.0.1")
- dst = tcptestutil.MustParse4("16.0.0.2")
+ src = testutil.MustParse4("16.0.0.1")
+ dst = testutil.MustParse4("16.0.0.2")
)
s := stack.New(stack.Options{
@@ -3252,7 +3252,7 @@ func TestCloseLocking(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- // Perform NAT so that the endoint tries to search for a sibling endpoint
+ // Perform NAT so that the endpoint tries to search for a sibling endpoint
// which ends up taking the protocol and endpoint lock (in that order).
table := stack.Table{
Rules: []stack.Rule{
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4051fda07..94caaae6c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set. See icmp/protocol.go:protocol.Parse for a full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
if !ok {
received.invalid.Increment()
@@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
@@ -1033,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool {
return false
}
+// icmpReasonHostUnreachable is an error in which the host specified in the
+// internet destination field of the datagram is unreachable.
+type icmpReasonHostUnreachable struct{}
+
+func (*icmpReasonHostUnreachable) isICMPReason() {}
+func (*icmpReasonHostUnreachable) isForwarding() bool {
+ // If we hit a Host Unreachable error, then we know we are operating as a
+ // router. As per RFC 4443 page 8, Destination Unreachable Message,
+ //
+ // If the reason for the failure to deliver cannot be mapped to any of
+ // other codes, the Code field is set to 3. Example of such cases are
+ // an inability to resolve the IPv6 destination address into a
+ // corresponding link address, or a link-specific problem of some sort.
+ return true
+}
+
+func (*icmpReasonHostUnreachable) respondsToMulticast() bool {
+ return false
+}
+
// icmpReasonPacketTooBig is an error where a packet is too big to be sent
// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message.
type icmpReasonPacketTooBig struct{}
@@ -1147,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
defer route.Release()
p.mu.Lock()
- netEP, ok := p.mu.eps[pkt.NICID]
+ // We retrieve an endpoint using the newly constructed route's NICID rather
+ // than the packet's NICID. The packet's NICID corresponds to the NIC on
+ // which it arrived, which isn't necessarily the same as the NIC on which it
+ // will be transmitted. On the other hand, the route's NIC *is* guaranteed
+ // to be the NIC on which the packet will be transmitted.
+ netEP, ok := p.mu.eps[route.NICID()]
p.mu.Unlock()
if !ok {
return &tcpip.ErrNotConnected{}
@@ -1226,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonHostUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonPacketTooBig:
icmpHdr.SetType(header.ICMPv6PacketTooBig)
icmpHdr.SetCode(header.ICMPv6UnusedCode)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 040cd4bc8..c2e9544c1 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,16 @@ package ipv6
import (
"bytes"
- "context"
"net"
"reflect"
"strings"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -46,16 +45,12 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 30 * time.Second
-
arbitraryHopLimit = 42
)
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
- lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
@@ -371,6 +366,8 @@ type testContext struct {
linkEP0 *channel.Endpoint
linkEP1 *channel.Endpoint
+
+ clock *faketime.ManualClock
}
type endpointWithResolutionCapability struct {
@@ -382,15 +379,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
}
func newTestContext(t *testing.T) *testContext {
+ clock := faketime.NewManualClock()
c := &testContext{
s0: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ Clock: clock,
}),
s1: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ Clock: clock,
}),
+ clock: clock,
}
c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
@@ -457,10 +458,14 @@ type routeArgs struct {
remoteLinkAddr tcpip.LinkAddress
}
-func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
+func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi, _ := args.src.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ pi, ok := args.src.Read()
+ if !ok {
+ t.Fatal("packet didn't arrive")
+ }
{
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -533,7 +538,7 @@ func TestLinkResolution(t *testing.T) {
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
} {
- routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
+ routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) {
if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
}
@@ -544,7 +549,7 @@ func TestLinkResolution(t *testing.T) {
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
} {
- routeICMPv6Packet(t, args, nil)
+ routeICMPv6Packet(t, c.clock, args, nil)
}
}
@@ -1309,7 +1314,7 @@ func TestPacketQueing(t *testing.T) {
Length: header.UDPMinimumSize,
})
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1325,7 +1330,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1371,7 +1376,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1396,9 +1401,11 @@ func TestPacketQueing(t *testing.T) {
e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -1421,7 +1428,8 @@ func TestPacketQueing(t *testing.T) {
// Wait for a neighbor solicitation since link address resolution should
// be performed.
{
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1475,6 +1483,7 @@ func TestPacketQueing(t *testing.T) {
}
// Expect the response now that the link address has resolved.
+ clock.RunImmediatelyScheduledJobs()
test.checkResp(t, e)
// Since link resolution was already performed, it shouldn't be performed
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f0e06f86b..8c8fafcda 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -178,7 +184,6 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
stats sharedStats
// enabled is set to 1 when the endpoint is enabled and 0 when it is
@@ -187,6 +192,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -219,11 +230,11 @@ type endpoint struct {
// If the NIC was created with a name, it is passed to NICNameFromID.
//
// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
-// generated for the same prefix on differnt NICs.
+// generated for the same prefix on different NICs.
type NICNameFromID func(tcpip.NICID, string) string
// OpaqueInterfaceIdentifierOptions holds the options related to the generation
-// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+// of opaque interface identifiers (IIDs) as defined by RFC 7217.
type OpaqueInterfaceIdentifierOptions struct {
// NICNameFromID is a function that returns a stable name for a specified NIC,
// even if the NIC ID changes over time.
@@ -270,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // If we are operating as a router, we should return an ICMP error to the
+ // original packet's sender.
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP
+ // errors to local endpoints.
+ e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt)
+ e.stats.ip.Forwarding.Errors.Increment()
+ e.stats.ip.Forwarding.HostUnreachable.Increment()
+ return
+ }
// handleControl expects the entire offending packet to be in the packet
// buffer's data field.
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -405,20 +426,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
// As per RFC 4291 section 2.8:
//
@@ -506,7 +545,7 @@ func (e *endpoint) Enable() tcpip.Error {
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
- // Addresses may have aleady completed DAD but in the time since the endpoint
+ // Addresses may have already completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err tcpip.Error
e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
@@ -611,8 +650,8 @@ func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
+// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the
+// network layer max header length.
func (e *endpoint) MTU() uint32 {
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
if err != nil {
@@ -636,8 +675,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params
if length > math.MaxUint16 {
return &tcpip.ErrMessageTooLong{}
}
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
- ip.Encode(&header.IPv6Fields{
+ header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{
PayloadLength: uint16(length),
TransportProtocol: params.Protocol,
HopLimit: params.TTL,
@@ -717,9 +755,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -788,7 +826,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
return nil
}
-// WritePackets implements stack.NetworkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.
func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("not implemented")
@@ -879,20 +917,20 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return &tcpip.ErrMalformedHeader{}
}
- ip := header.IPv6(h)
+ ipH := header.IPv6(h)
// Always set the payload length.
pktSize := pkt.Data().Size()
- ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+ ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
// Set the source address when zero.
- if ip.SourceAddress() == header.IPv6Any {
- ip.SetSourceAddress(r.LocalAddress())
+ if ipH.SourceAddress() == header.IPv6Any {
+ ipH.SetSourceAddress(r.LocalAddress())
}
// Set the destination. If the packet already included a destination, it will
// be part of the route anyways.
- ip.SetDestinationAddress(r.RemoteAddress())
+ ipH.SetDestinationAddress(r.RemoteAddress())
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
@@ -953,7 +991,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -1066,7 +1105,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -1085,10 +1124,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1109,7 +1148,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1137,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1580,7 +1618,7 @@ func (e *endpoint) Close() {
e.protocol.forgetEndpoint(e.nic.ID())
}
-// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+// NetworkProtocolNumber implements stack.NetworkEndpoint.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
@@ -1589,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1631,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
@@ -1932,7 +1970,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1957,12 +1994,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -1981,7 +2012,7 @@ func (p *protocol) DefaultPrefixLen() int {
return header.IPv6AddressSize * 8
}
-// ParseAddresses implements NetworkProtocol.ParseAddresses.
+// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
return h.SourceAddress(), h.DestinationAddress()
@@ -2058,7 +2089,7 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// SetOption implements NetworkProtocol.SetOption.
+// SetOption implements stack.NetworkProtocol.
func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -2069,7 +2100,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E
}
}
-// Option implements NetworkProtocol.Option.
+// Option implements stack.NetworkProtocol.
func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -2090,10 +2121,10 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
-// Close implements stack.TransportProtocol.Close.
+// Close implements stack.TransportProtocol.
func (*protocol) Close() {}
-// Wait implements stack.TransportProtocol.Wait.
+// Wait implements stack.TransportProtocol.
func (*protocol) Wait() {}
// parseAndValidate parses the packet (including its transport layer header) and
@@ -2127,7 +2158,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool)
return h, true
}
-// Parse implements stack.NetworkProtocol.Parse.
+// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
@@ -2137,35 +2168,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2186,7 +2188,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error
return 0, &tcpip.ErrMalformedHeader{}
}
- networkMTU := linkMTU - uint32(networkHeadersLen)
+ networkMTU := linkMTU - networkHeadersLen
if networkMTU > maxPayloadSize {
networkMTU = maxPayloadSize
}
@@ -2205,7 +2207,7 @@ type Options struct {
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate a link-local adddress.
+ // further attempts are made to auto-generate a link-local address.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
AutoGenLinkLocal bool
@@ -2221,7 +2223,7 @@ type Options struct {
// TempIIDSeed is used to seed the initial temporary interface identifier
// history value used to generate IIDs for temporary SLAAC addresses.
//
- // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // Temporary SLAAC addresses are short-lived addresses which are unpredictable
// and random from the perspective of other nodes on the network. It is
// recommended that the seed be a random byte buffer of at least
// header.IIDSize bytes to make sure that temporary SLAAC addresses are
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 30325160a..d2a23fd4f 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -129,7 +129,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
// UDP checksum
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
payloadLength := hdr.UsedLength()
@@ -402,7 +402,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}{
{
name: "None",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr },
shouldAccept: true,
expectICMP: false,
},
@@ -612,8 +612,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
{
name: "No next header",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{},
- noNextHdrID
+ return nil, noNextHdrID
},
shouldAccept: false,
expectICMP: false,
@@ -1160,7 +1159,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
ipv6Payload1Addr1ToAddr2,
},
@@ -1180,7 +1179,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
ipv6Payload3Addr1ToAddr2,
},
@@ -1202,7 +1201,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1218,7 +1217,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1240,7 +1239,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1256,7 +1255,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1278,7 +1277,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1296,7 +1295,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 8, More = false, ID = 1
// NextHeader value is different than the one in the first fragment, so
// this NextHeader should be ignored.
- buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1318,7 +1317,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[:64],
},
@@ -1334,7 +1333,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[64:],
},
@@ -1356,7 +1355,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[:63],
},
@@ -1372,7 +1371,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[63:],
},
@@ -1394,7 +1393,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1410,7 +1409,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1432,7 +1431,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
@@ -1448,10 +1447,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ []byte{uint8(header.UDPProtocolNumber), 0,
udpMaximumSizeMinus15 >> 8,
udpMaximumSizeMinus15 & 0xff,
- 0, 0, 0, 1}),
+ 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
@@ -1473,7 +1472,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
@@ -1489,10 +1488,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ []byte{uint8(header.UDPProtocolNumber), 0,
udpMaximumSizeMinus15 >> 8,
(udpMaximumSizeMinus15 & 0xff) + 1,
- 0, 0, 0, 1}),
+ 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
@@ -1514,12 +1513,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1535,12 +1534,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1562,12 +1561,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1583,12 +1582,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1610,12 +1609,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1631,7 +1630,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1653,12 +1652,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1674,7 +1673,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1699,12 +1698,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header (part 1)
//
// Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5},
},
),
},
@@ -1721,10 +1720,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
// Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+ []byte{6, 7, 8, 9, 10, 11, 12, 13},
ipv6Payload1Addr1ToAddr2,
},
@@ -1749,12 +1748,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header (part 1)
//
// Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5},
},
),
},
@@ -1771,10 +1770,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
// Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+ []byte{6, 7, 8, 9, 10, 11, 12, 13},
ipv6Payload1Addr1ToAddr2,
},
@@ -1798,7 +1797,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1816,7 +1815,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1},
ipv6Payload2Addr1ToAddr2,
},
@@ -1832,7 +1831,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1854,7 +1853,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1870,7 +1869,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2},
ipv6Payload2Addr1ToAddr2[:32],
},
@@ -1886,7 +1885,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1902,7 +1901,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2},
ipv6Payload2Addr1ToAddr2[32:],
},
@@ -1924,7 +1923,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1940,7 +1939,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr3ToAddr2[:32],
},
@@ -1956,7 +1955,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1972,7 +1971,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1},
ipv6Payload1Addr3ToAddr2[32:],
},
@@ -2208,7 +2207,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
checker.ICMPv6Type(test.expectICMPType),
checker.ICMPv6Code(test.expectICMPCode),
checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
- checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ checker.ICMPv6Payload(expectICMPPayload),
),
)
})
@@ -2459,7 +2458,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6TimeExceeded),
checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
- checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ checker.ICMPv6Payload(firstFragmentSent),
),
)
})
@@ -2795,11 +2794,7 @@ var fragmentationTests = []struct {
}
func TestFragmentationWritePacket(t *testing.T) {
- const (
- ttl = 42
- tos = stack.DefaultTOS
- transportProto = tcp.ProtocolNumber
- )
+ const ttl = 42
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
@@ -3004,17 +2999,17 @@ func TestFragmentationErrors(t *testing.T) {
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
)
- ipv6Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10::1").To16()),
PrefixLen: 64,
}
- ipv6Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11::1").To16()),
PrefixLen: 64,
}
@@ -3022,6 +3017,7 @@ func TestForwarding(t *testing.T) {
Address: tcpip.Address(net.ParseIP("ff00::").To16()),
PrefixLen: 64,
}
+
remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
@@ -3296,36 +3292,36 @@ func TestForwarding(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv6Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv6Addr.Subnet(),
+ NIC: incomingNICID,
},
{
- Destination: ipv6Addr2.Subnet(),
- NIC: nicID2,
+ Destination: outgoingIPv6Addr.Subnet(),
+ NIC: outgoingNICID,
},
{
Destination: multicastIPv6Addr.Subnet(),
- NIC: nicID2,
+ NIC: outgoingNICID,
},
})
@@ -3334,7 +3330,7 @@ func TestForwarding(t *testing.T) {
}
transportProtocol := header.ICMPv6ProtocolNumber
- extHdrBytes := []byte{}
+ var extHdrBytes []byte
extHdrChecker := checker.IPv6ExtHdr()
if test.extHdr != nil {
nextHdrID := hopByHopExtHdrID
@@ -3348,15 +3344,15 @@ func TestForwarding(t *testing.T) {
totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
- icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
-
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv6EchoRequest)
- icmp.SetCode(header.ICMPv6UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
+ icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
+
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
Src: test.sourceAddr,
Dst: test.destAddr,
}))
@@ -3372,9 +3368,10 @@ func TestForwarding(t *testing.T) {
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
- e1.InjectInbound(ProtocolNumber, requestPkt)
+ incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
- reply, ok := e1.Read()
if test.expectErrorICMP {
if !ok {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
@@ -3393,31 +3390,31 @@ func TestForwarding(t *testing.T) {
return len(hdr.View())
}
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv6Addr1.Address),
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(incomingIPv6Addr.Address),
checker.DstAddr(test.sourceAddr),
checker.TTL(DefaultTTL),
checker.ICMPv6(
checker.ICMPv6Type(test.icmpType),
checker.ICMPv6Code(test.icmpCode),
- checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
+ checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]),
),
)
- if n := e2.Drain(); n != 0 {
+ if n := outgoingEndpoint.Drain(); n != 0 {
t.Fatalf("got e2.Drain() = %d, want = 0", n)
}
} else if ok {
t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
}
- reply, ok = e2.Read()
+ reply, ok = outgoingEndpoint.Read()
if test.expectPacketForwarded {
if !ok {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
- checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(test.sourceAddr),
checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
@@ -3429,7 +3426,7 @@ func TestForwarding(t *testing.T) {
),
)
- if n := e1.Drain(); n != 0 {
+ if n := incomingEndpoint.Drain(); n != 0 {
t.Fatalf("got e1.Drain() = %d, want = 0", n)
}
} else if ok {
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 71d1c3e28..bc9cf6999 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -16,6 +16,7 @@ package ipv6_test
import (
"bytes"
+ "math/rand"
"testing"
"time"
@@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
+ SecureRNG: &secureRNG,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index b29fed347..ee36ed254 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -16,7 +16,6 @@ package ipv6
import (
"fmt"
- "math/rand"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -215,28 +214,23 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
- // OnDefaultRouterDiscovered is called when a new default router is
- // discovered. Implementations must return true if the newly discovered
- // router should be remembered.
+ // OnOffLinkRouteUpdated is called when an off-link route is updated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
+ OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address)
- // OnDefaultRouterInvalidated is called when a discovered default router that
- // was remembered is invalidated.
+ // OnOffLinkRouteInvalidated is called when an off-link route is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
+ OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
- // Implementations must return true if the newly discovered on-link prefix
- // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet)
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
@@ -246,13 +240,11 @@ type NDPDispatcher interface {
OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
- // configuration flag set is received and SLAAC was performed. Implementations
- // may prevent the stack from assigning the address to the NIC by returning
- // false.
+ // configuration flag set is received and SLAAC was performed.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix)
// OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
// is deprecated, but is still considered valid. Note, if an address is
@@ -549,7 +541,7 @@ type tempSLAACAddrState struct {
// Must not be nil.
regenJob *tcpip.Job
- createdAt time.Time
+ createdAt tcpip.MonotonicTime
// The address's endpoint.
//
@@ -573,10 +565,10 @@ type slaacPrefixState struct {
invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
- validUntil time.Time
+ validUntil tcpip.MonotonicTime
// Nonzero only when the address is not preferred forever.
- preferredUntil time.Time
+ preferredUntil tcpip.MonotonicTime
// State associated with the stable address generated for the prefix.
stableAddr struct {
@@ -705,7 +697,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -832,7 +824,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
// Let the integrator know a discovered default router is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip)
}
}
@@ -849,11 +841,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
- // Informed by the integrator to not remember the router, do
- // nothing further.
- return
- }
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip)
state := defaultRouterState{
invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
@@ -879,11 +867,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
- // Informed by the integrator to not remember the prefix, do
- // nothing further.
- return
- }
+ ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix)
state := onLinkPrefixState{
invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
@@ -1051,7 +1035,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
}
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// The time an address is preferred until is needed to properly generate the
// address.
@@ -1097,16 +1081,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
- // Informed by the integrator not to add the address.
- return nil
- }
-
addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
+ ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr)
+
return addressEndpoint
}
@@ -1182,7 +1163,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
state.stableAddr.localGenerationFailures++
}
- if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
@@ -1237,13 +1218,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
// address is the lower of the valid lifetime of the stable address or the
// maximum temporary address valid lifetime.
vl := ndp.configs.MaxTempAddrValidLifetime
- if prefixState.validUntil != (time.Time{}) {
+ if prefixState.validUntil != (tcpip.MonotonicTime{}) {
if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
vl = prefixVL
}
@@ -1259,7 +1240,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// maximum temporary address preferred lifetime - the temporary address desync
// factor.
pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
- if prefixState.preferredUntil != (time.Time{}) {
+ if prefixState.preferredUntil != (tcpip.MonotonicTime{}) {
if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
// Respect the preferred lifetime of the prefix, as per RFC 4941 section
// 3.3 step 4.
@@ -1394,7 +1375,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// deprecation job so it can be reset.
prefixState.deprecationJob.Cancel()
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
@@ -1403,7 +1384,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
prefixState.preferredUntil = now.Add(pl)
} else {
- prefixState.preferredUntil = time.Time{}
+ prefixState.preferredUntil = tcpip.MonotonicTime{}
}
// As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
@@ -1421,17 +1402,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// Handle the infinite valid lifetime separately as we do not schedule a
// job in this case.
prefixState.invalidationJob.Cancel()
- prefixState.validUntil = time.Time{}
+ prefixState.validUntil = tcpip.MonotonicTime{}
} else {
var effectiveVl time.Duration
var rl time.Duration
// If the prefix was originally set to be valid forever, assume the
// remaining time to be the maximum possible value.
- if prefixState.validUntil == (time.Time{}) {
+ if prefixState.validUntil == (tcpip.MonotonicTime{}) {
rl = header.NDPInfiniteLifetime
} else {
- rl = time.Until(prefixState.validUntil)
+ rl = prefixState.validUntil.Sub(now)
}
if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
@@ -1463,7 +1444,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// maximum temporary address valid lifetime. Note, the valid lifetime of a
// temporary address is relative to the address's creation time.
validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
- if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 {
validUntil = prefixState.validUntil
}
@@ -1483,7 +1464,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// desync factor. Note, the preferred lifetime of a temporary address is
// relative to the address's creation time.
preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
- if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
preferredUntil = prefixState.preferredUntil
}
@@ -1710,7 +1691,7 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
return
}
@@ -1718,7 +1699,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// 4861 section 6.3.7.
var delay time.Duration
if ndp.configs.MaxRtrSolicitationDelay > 0 {
- delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
// Protected by ndp.ep.mu.
@@ -1754,7 +1735,7 @@ func (ndp *ndpState) startSolicitingRouters() {
header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
- payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length()
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
rs := header.NDPRouterSolicit(icmpData.MessageBody())
@@ -1861,7 +1842,7 @@ func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
- ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
+ ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 570c6c00c..95d23f200 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -16,7 +16,7 @@ package ipv6
import (
"bytes"
- "context"
+ "math/rand"
"strings"
"testing"
"time"
@@ -32,58 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
-// setupStackAndEndpoint creates a stack with a single NIC with a link-local
-// address llladdr and an IPv6 endpoint to a remote with link-local address
-// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
-
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
- t.Cleanup(ep.Close)
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := llladdr.WithPrefix()
- if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- addressEP.DecRef()
- }
-
- return s, ep
-}
-
var _ NDPDispatcher = (*testNDPDispatcher)(nil)
// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
@@ -94,24 +42,21 @@ type testNDPDispatcher struct {
func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
t.addr = addr
- return true
}
-func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
t.addr = addr
}
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
}
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return false
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
@@ -163,11 +108,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
}
}
-type linkResolutionResult struct {
- linkAddr tcpip.LinkAddress
- ok bool
-}
-
// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -456,8 +396,10 @@ func TestNeighborSolicitationResponse(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ Clock: clock,
})
e := channel.New(1, 1280, nicLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -527,7 +469,8 @@ func TestNeighborSolicitationResponse(t *testing.T) {
}
if test.performsLinkResolution {
- p, got := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, got := e.Read()
if !got {
t.Fatal("expected an NDP NS response")
}
@@ -582,7 +525,8 @@ func TestNeighborSolicitationResponse(t *testing.T) {
}))
}
- p, got := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, got := e.Read()
if !got {
t.Fatal("expected an NDP NA response")
}
@@ -732,15 +676,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
+ const nicID = 1
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
var extHdrs header.IPv6ExtHdrSerializer
@@ -865,6 +801,11 @@ func TestNDPValidation(t *testing.T) {
},
}
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+
for _, typ := range types {
for _, isRouter := range []bool{false, true} {
name := typ.name
@@ -875,7 +816,10 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
if isRouter {
if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
@@ -883,17 +827,35 @@ func TestNDPValidation(t *testing.T) {
}
}
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatal("cannot find network endpoint instance for IPv6")
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }})
+
stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp[:typ.size],
+ icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmpH[typ.size:], typ.extraData)
+ icmpH.SetType(typ.typ)
+ icmpH.SetCode(test.code)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH[:typ.size],
Src: lladdr0,
Dst: lladdr1,
PayloadCsum: header.Checksum(typ.extraData /* initial */, 0),
@@ -907,19 +869,19 @@ func TestNDPValidation(t *testing.T) {
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ t.Errorf("got invalid.Value() = %d, want = 0", got)
}
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ // Should initially not have dropped any packets.
if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ t.Errorf("got routerOnly.Value() = %d, want = 0", got)
}
if t.Failed() {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
+ handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
@@ -932,18 +894,18 @@ func TestNDPValidation(t *testing.T) {
want = 1
}
if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ t.Errorf("got invalid.Value() = %d, want = %d", got, want)
}
want = 0
if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
+ // Router only packets are expected to be dropped when operating
+ // as a host.
want = 1
}
if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
}
-
})
}
})
@@ -1279,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- Clock: clock,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
DADConfigs: dadConfigs,
})},
@@ -1297,7 +1260,8 @@ func TestCheckDuplicateAddress(t *testing.T) {
snmc := header.SolicitedNodeAddr(lladdr0)
remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
checkDADMsg := func() {
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("expected %d-th DAD message", dadPacketsSent)
}
@@ -1345,7 +1309,7 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err)
}
// Should not restart DAD since we already requested DAD above - the handler
- // should be called when the original request compeletes so we should not send
+ // should be called when the original request completes so we should not send
// an extra DAD message here.
dadRequestsMade++
if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b5b013b64..854d6a6ba 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
// Wildcard destinations affect all destinations for TupleOnly.
if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
+ intersection &= (^TupleOnlyFlag) | counter.SharedFlags()
count++
}
}
@@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error)
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint32(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rng.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
@@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester)
// An optional PortTester can be passed in which if provided will be used to
// test if the picked port can be used. The function should return true if the
// port is safe to use, false otherwise.
-func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv
}
// A port wasn't specified, so try to find one.
- return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
if !pm.reserveSpecificPortLocked(res) {
return false, nil
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 6c4fb8c68..a91b130df 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -18,6 +18,7 @@ import (
"math"
"math/rand"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) {
t.Run(test.tname, func(t *testing.T) {
pm := NewPortManager()
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for _, test := range test.actions {
first, _ := pm.PortRange()
@@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) {
BindToDevice: test.device,
Dest: test.dest,
}
- gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
+ gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
@@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- port, err := pm.PickEphemeralPort(test.f)
+ port, err := pm.PickEphemeralPort(rng, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index b26936b7f..0ea85f9ed 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -222,7 +222,7 @@ type SocketOptions struct {
getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
// receiveBufferSize determines the receive buffer size for this socket.
- receiveBufferSize int64
+ receiveBufferSize atomicbitops.AlignedAtomicInt64
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 {
return atomic.LoadInt32(&so.bindToDevice)
}
-// SetBindToDevice sets value for SO_BINDTODEVICE option.
+// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is
+// zero, the socket device binding is removed.
func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
- if !so.handler.HasNIC(bindToDevice) {
+ if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) {
return &ErrUnknownDevice{}
}
@@ -653,13 +654,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
// GetReceiveBufferSize gets value for SO_RCVBUF option.
func (so *SocketOptions) GetReceiveBufferSize() int64 {
- return atomic.LoadInt64(&so.receiveBufferSize)
+ return so.receiveBufferSize.Load()
}
// SetReceiveBufferSize sets value for SO_RCVBUF option.
func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
if !notify {
- atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize)
+ so.receiveBufferSize.Store(receiveBufferSize)
return
}
@@ -684,8 +685,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo
v = math.MaxInt32
}
- oldSz := atomic.LoadInt64(&so.receiveBufferSize)
+ oldSz := so.receiveBufferSize.Load()
// Notify endpoint about change in buffer size.
newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
- atomic.StoreInt64(&so.receiveBufferSize, newSz)
+ so.receiveBufferSize.Store(newSz)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 84aa6a9e4..395ff9a07 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,6 +56,7 @@ go_library(
"neighbor_entry_list.go",
"neighborstate_string.go",
"nic.go",
+ "nic_stats.go",
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 5720e7543..18e0d4374 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -35,7 +35,6 @@ import (
// Currently, only TCP tracking is supported.
// Our hash table has 16K buckets.
-// TODO(gvisor.dev/issue/170): These should be tunable.
const numBuckets = 1 << 14
// Direction of the tuple.
@@ -125,6 +124,8 @@ type conn struct {
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection. It is protected by mu.
+ //
+ // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
lastUsed time.Time `state:".(unixTime)"`
}
@@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
- // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
- // other tcp states.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
} else if hook == cn.tcbHook {
@@ -244,8 +243,7 @@ func (ct *ConnTrack) init() {
// connFor gets the conn for pkt if it exists, or returns nil
// if it does not. It returns an error when pkt does not contain a valid TCP
// header.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
-// other transport protocols.
+// TODO(gvisor.dev/issue/6168): Support UDP.
func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
tid, err := packetToTupleID(pkt)
if err != nil {
@@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ // TODO(gvisor.dev/issue/6168): Support UDP.
if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -464,8 +462,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
// Update the state of tcb.
- // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
- // other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -542,8 +538,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused returns the next bucket that should be checked and the time after
// which it should be called again.
func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
- // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
- // it is in Linux via sysctl.
const fractionPerReaping = 128
const maxExpiredPct = 50
const maxFullTraversal = 60 * time.Second
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ff555722e..72f66441f 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -54,6 +55,11 @@ type fwdTestNetworkEndpoint struct {
nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
@@ -109,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
-func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -129,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -169,11 +171,6 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
if fn := f.proto.onLinkAddressResolved; fn != nil {
- time.AfterFunc(f.proto.addrResolveDelay, func() {
+ f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() {
fn(f.proto.neigh, addr, remoteLinkAddr)
})
}
@@ -242,16 +239,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
@@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) {
panic("not implemented")
}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) {
+ clock := faketime.NewManualClock()
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
proto.stack = s
return proto
}},
+ Clock: clock,
})
protoNum := proto.Number()
@@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
// NIC 1 has the link address "a", and added the network address 1.
- ep1 = &fwdTestLinkEndpoint{
+ ep1 := &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "a",
@@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
// NIC 2 has the link address "b", and added the network address 2.
- ep2 = &fwdTestLinkEndpoint{
+ ep2 := &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "b",
@@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
}
- return ep1, ep2
+ return clock, ep1, ep2
}
func TestForwardingWithStaticResolver(t *testing.T) {
@@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
default:
@@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Whether or not we use the neighbor cache here does not matter since
// neither linkAddrCache nor neighborCache will be used.
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) {
Data: buf.ToVectorisedView(),
}))
+ clock.Advance(proto.addrResolveDelay)
select {
case <-ep2.C:
t.Fatal("Packet should not be forwarded")
- case <-time.After(time.Second):
+ default:
}
}
@@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
const numPackets int = 5
// These packets will all be enqueued in the packet queue to wait for link
@@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
}
// All packets should fail resolution.
- // TODO(gvisor.dev/issue/5141): Use a fake clock.
for i := 0; i < numPackets; i++ {
+ clock.Advance(proto.addrResolveDelay)
select {
case got := <-ep2.C:
t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
- case <-time.After(100 * time.Millisecond):
+ default:
}
}
}
@@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
}
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
@@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
@@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
for i := 0; i < 2; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
@@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
for i := 0; i < maxPendingResolutions+5; i++ {
// Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
@@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
for i := 0; i < maxPendingResolutions; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 3670d5995..f152c0d83 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
+func DefaultTables(seed uint32) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,7 @@ func DefaultTables() *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: generateRandUint32(),
+ seed: seed,
},
reaperDone: make(chan struct{}, 1),
}
@@ -268,10 +268,6 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
-// which address can be gathered. Currently, address is only needed for
-// prerouting.
-//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
@@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) {
select {
case <-it.reaperDone:
return
+ // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
case <-time.After(interval):
bucket, interval = it.connections.reapUnused(bucket, interval)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 2812c89aa..91e266de8 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
-//
-// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
-// them.
type RedirectTarget struct {
// Port indicates port used to redirect. It is immutable.
Port uint16
@@ -100,9 +97,6 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for Prerouting and calls pkt.Clone(), neither
-// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
@@ -136,8 +130,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
- // we need to change dest address (for OUTPUT chain) or ports.
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 93592e7f5..66e5f22ac 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -242,7 +242,6 @@ type IPHeaderFilter struct {
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
- // TODO(gvisor.dev/issue/170): Support other filter fields.
transProto tcpip.TransportProtocolNumber
dstAddr tcpip.Address
srcAddr tcpip.Address
@@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return true
case Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index d4ac9e1f8..b5c6626d6 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -16,14 +16,14 @@ package stack_test
import (
"bytes"
- "context"
"encoding/binary"
"fmt"
+ "math/rand"
"testing"
"time"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
+ cryptorand "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -112,11 +112,12 @@ type ndpDADEvent struct {
res stack.DADResult
}
-type ndpRouterEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- // true if router was discovered, false if invalidated.
- discovered bool
+type ndpOffLinkRouteEvent struct {
+ nicID tcpip.NICID
+ subnet tcpip.Subnet
+ router tcpip.Address
+ // true if route was updated, false if invalidated.
+ updated bool
}
type ndpPrefixEvent struct {
@@ -167,10 +168,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// related events happen for test purposes.
type ndpDispatcher struct {
dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
+ offLinkRouteC chan ndpOffLinkRouteEvent
prefixC chan ndpPrefixEvent
- rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
dnsslC chan ndpDNSSLEvent
@@ -189,32 +188,32 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add
}
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated.
+func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
+ if c := n.offLinkRouteC; c != nil {
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
true,
}
}
-
- return n.rememberRouter
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated.
+func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
+ if c := n.offLinkRouteC; c != nil {
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
false,
}
}
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
nicID,
@@ -222,8 +221,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
true,
}
}
-
- return n.rememberPrefix
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
@@ -237,7 +234,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
}
}
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
nicID,
@@ -245,7 +242,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
newAddr,
}
}
- return true
}
func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
@@ -481,13 +477,9 @@ func TestDADResolve(t *testing.T) {
}
for _, test := range tests {
- test := test
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
e := channelLinkWithHeaderLength{
@@ -499,8 +491,11 @@ func TestDADResolve(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: stack.DADConfigurations{
@@ -529,14 +524,10 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ const delta = time.Nanosecond
+ clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Error(err)
}
@@ -566,13 +557,14 @@ func TestDADResolve(t *testing.T) {
}
// Wait for DAD to resolve.
+ clock.Advance(delta)
select {
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID)
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
t.Error(err)
@@ -610,7 +602,10 @@ func TestDADResolve(t *testing.T) {
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p, _ := e.ReadContext(context.Background())
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("packet didn't arrive")
+ }
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
@@ -736,11 +731,13 @@ func TestDADFail(t *testing.T) {
dadConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -766,16 +763,17 @@ func TestDADFail(t *testing.T) {
// Wait for DAD to fail and make sure the address did
// not get resolved.
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the
- // expected resolution time + extra 1s buffer,
- // something is wrong.
- t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ // If we don't get a failure event after advancing
+ // the clock by the expected resolution time,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
@@ -844,11 +842,13 @@ func TestDADStop(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
@@ -866,15 +866,16 @@ func TestDADStop(t *testing.T) {
test.stopFn(t, s)
// Wait for DAD to fail (since the address was removed during DAD).
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ // If we don't get a failure event after advancing the clock by the
+ // expected resolution time, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
}
if !test.skipFinalAddrCheck {
@@ -925,10 +926,12 @@ func TestSetNDPConfigurations(t *testing.T) {
dadC: make(chan ndpDADEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
@@ -1007,28 +1010,23 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Fatal(err)
}
- // Sleep until right (500ms before) before resolution to
- // make sure the address didn't resolve on NIC(1) yet.
- const delta = 500 * time.Millisecond
- time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ // Advance the clock to right before resolution to make sure the address
+ // didn't resolve on NIC(1) yet.
+ const delta = 1
+ clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
}
// Wait for DAD to resolve.
+ clock.Advance(delta)
select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD resolution")
}
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
t.Fatal(err)
@@ -1040,7 +1038,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
@@ -1053,13 +1051,13 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
if managedAddress {
// The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
// of the RA payload.
- raPayload[1] |= (1 << 7)
+ raPayload[1] |= 1 << 7
}
// Populate the Other Configurations flag field.
if otherConfigurations {
// The Other Configurations flag field is the 6th bit of byte #1
// (0-indexing) of the RA payload.
- raPayload[1] |= (1 << 6)
+ raPayload[1] |= 1 << 6
}
opts := ra.Options()
opts.Serialize(optSer)
@@ -1203,9 +1201,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
ndpConfigs := test.config(enable)
ndpConfigs.HandleRAs = handle
@@ -1275,8 +1273,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e)
default:
}
select {
@@ -1303,9 +1301,9 @@ func boolToUint64(v bool) uint64 {
}
// Check e to make sure that the event is for addr on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
- return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+// update flag set to updated.
+func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, router tcpip.Address, updated bool) string {
+ return cmp.Diff(ndpOffLinkRouteEvent{nicID: 1, subnet: header.IPv6EmptySubnet, router: router, updated: updated}, e, cmp.AllowUnexported(e))
}
func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
@@ -1338,56 +1336,13 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b
}
}
-// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered router when the dispatcher asks it not to.
-func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA for a router we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the router in the first place.
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
- }
-}
-
func TestRouterDiscovery(t *testing.T) {
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1396,30 +1351,32 @@ func TestRouterDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ expectOffLinkRouteEvent := func(addr tcpip.Address, updated bool) {
t.Helper()
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, addr, updated); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
}
}
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, addr, false); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for router discovery event")
}
}
@@ -1436,26 +1393,26 @@ func TestRouterDiscovery(t *testing.T) {
// remembered.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
default:
}
// Rx an RA from lladdr2 with a huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ expectOffLinkRouteEvent(llAddr2, true)
// Rx an RA from another router (lladdr3) with non-zero lifetime.
const l3LifetimeSeconds = 6
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ expectOffLinkRouteEvent(llAddr3, true)
// Rx an RA from lladdr2 with lesser lifetime.
const l2LifetimeSeconds = 2
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
default:
}
@@ -1466,15 +1423,15 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ expectOffLinkRouteEvent(llAddr2, true)
// Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
+ expectOffLinkRouteEvent(llAddr2, false)
// Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
@@ -1483,7 +1440,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
})
}
@@ -1491,8 +1448,7 @@ func TestRouterDiscovery(t *testing.T) {
// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1519,9 +1475,9 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
if i <= ipv6.MaxDiscoveredDefaultRouters {
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, llAddr, true); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
@@ -1529,7 +1485,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
} else {
select {
- case <-ndpDisp.routerC:
+ case <-ndpDisp.offLinkRouteC:
t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
default:
}
@@ -1543,51 +1499,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
}
-// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered on-link prefix when the dispatcher asks it not to.
-func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- prefix, subnet, _ := prefixSubnetAddr(0, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix that we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the prefix in the first place.
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
- }
-}
-
func TestPrefixDiscovery(t *testing.T) {
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
@@ -1595,10 +1506,10 @@ func TestPrefixDiscovery(t *testing.T) {
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1607,6 +1518,7 @@ func TestPrefixDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1667,12 +1579,13 @@ func TestPrefixDiscovery(t *testing.T) {
// Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
+ clock.Advance(time.Duration(lifetime) * time.Second)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1701,10 +1614,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
subnet := prefix.Subnet()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1713,6 +1626,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1736,21 +1650,23 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// with infinite valid lifetime which should not get invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
expectPrefixEvent(subnet, true)
+ clock.Advance(testInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ default:
}
// Receive an RA with finite lifetime.
// The prefix should get invalidated after 1s.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ clock.Advance(testInfiniteLifetime)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(testInfiniteLifetime):
+ default:
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1761,19 +1677,21 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ clock.Advance(testInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ default:
}
// Receive an RA with a prefix with a lifetime value greater than the
// set infinite lifetime value.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
// Receive an RA with 0 lifetime.
@@ -1786,8 +1704,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -3618,6 +3535,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3626,6 +3544,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3659,10 +3578,11 @@ func TestAutoGenAddrRemoval(t *testing.T) {
// Wait for the original valid lifetime to make sure the original job got
// cancelled/cleaned up.
+ clock.Advance(lifetimeSeconds * time.Second)
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
}
@@ -3784,6 +3704,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3792,6 +3713,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3821,30 +3743,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
// Should not get an invalidation event after the PI's invalidation
// time.
+ clock.Advance(lifetimeSeconds * time.Second)
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
+// makeSecretKey returns a newly generated random key that is
+// header.OpaqueIIDSecretKeyMinBytes bytes long, read from cryptorand. The
+// test is failed immediately (t.Fatalf) if the read reports an error or
+// fills fewer bytes than the key length.
+func makeSecretKey(t *testing.T) []byte {
+ secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes)
+ n, err := cryptorand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("cryptorand.Read(_): %s", err)
+ }
+ // A short read without an error would leave the tail of the key zeroed, so
+ // treat it as a failure as well.
+ if l := len(secretKey); n != l {
+ t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l)
+ }
+ return secretKey
+}
+
// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
// opaque interface identifiers when configured to do so.
func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const nicID = 1
const nicName = "nic1"
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+
+ secretKey := makeSecretKey(t)
prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
@@ -3866,6 +3794,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3880,6 +3809,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -3918,12 +3848,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
+ clock.Advance(validLifetimeSecondPrefix1 * time.Second)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3949,15 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}()
ipv6.MaxDesyncFactor = time.Nanosecond
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+ secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4236,13 +4159,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
addrType := addrType
t.Run(addrType.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
@@ -4253,6 +4175,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4297,7 +4220,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
})
}
@@ -4314,15 +4237,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
const maxRetries = 1
const lifetimeSeconds = 5
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+ secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4331,6 +4246,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -4350,6 +4266,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4380,7 +4297,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
// Simulate a DAD conflict after some time has passed.
- time.Sleep(failureTimer)
+ clock.Advance(failureTimer)
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
@@ -4395,12 +4312,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
// Let the next address resolve.
addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
expectAutoGenAddrEvent(addr, newAddr)
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
@@ -4414,6 +4332,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
//
// We expect either just the invalidation event or the deprecation event
// followed by the invalidation event.
+ clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer)
select {
case e := <-ndpDisp.autoGenAddrC:
if e.eventType == deprecatedAddr {
@@ -4426,7 +4345,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4434,7 +4353,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4696,11 +4615,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
)
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
@@ -4743,17 +4660,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
),
)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, llAddr3, true /* discovered */); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID)
}
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
@@ -4775,8 +4692,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpected router event = %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpected off-link route event = %#v", e)
default:
}
select {
@@ -4862,12 +4779,11 @@ func TestCleanupNDPState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents),
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
@@ -4879,16 +4795,17 @@ func TestCleanupNDPState(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
- expectRouterEvent := func() (bool, ndpRouterEvent) {
+ expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) {
select {
- case e := <-ndpDisp.routerC:
+ case e := <-ndpDisp.offLinkRouteC:
return true, e
default:
}
- return false, ndpRouterEvent{}
+ return false, ndpOffLinkRouteEvent{}
}
expectPrefixEvent := func() (bool, ndpPrefixEvent) {
@@ -4933,8 +4850,8 @@ func TestCleanupNDPState(t *testing.T) {
// multiple addresses.
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
@@ -4944,8 +4861,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
@@ -4955,8 +4872,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
@@ -4966,8 +4883,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
@@ -5008,14 +4925,14 @@ func TestCleanupNDPState(t *testing.T) {
test.cleanupFn(t, s)
// Collect invalidation events after having NDP state cleaned up.
- gotRouterEvents := make(map[ndpRouterEvent]int)
+ gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectRouterEvent()
+ ok, e := expectOffLinkRouteEvent()
if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
break
}
- gotRouterEvents[e]++
+ gotOffLinkRouteEvents[e]++
}
gotPrefixEvents := make(map[ndpPrefixEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
@@ -5042,14 +4959,14 @@ func TestCleanupNDPState(t *testing.T) {
t.FailNow()
}
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
}
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" {
+ t.Errorf("off-link route events mismatch (-want +got):\n%s", diff)
}
expectedPrefixEvents := map[ndpPrefixEvent]int{
{nicID: nicID1, prefix: subnet1, discovered: false}: 1,
@@ -5111,10 +5028,10 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
+ clock.Advance(lifetimeSeconds * time.Second)
select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
+ case <-ndpDisp.offLinkRouteC:
+ t.Error("unexpected off-link route event")
default:
}
select {
@@ -5139,7 +5056,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
ndpDisp := ndpDispatcher{
dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- rememberRouter: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -5241,6 +5157,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
expectNoDHCPv6Event()
}
+// Compile-time check that *savingRandSource implements rand.Source.
+var _ rand.Source = (*savingRandSource)(nil)
+
+// savingRandSource wraps a rand.Source and remembers the most recent value
+// handed out by Int63 so that a test can read it back afterwards.
+type savingRandSource struct {
+ // s is the wrapped source that actually produces the values.
+ s rand.Source
+
+ // lastInt63 holds the value returned by the latest Int63 call; it keeps its
+ // zero value until Int63 is called for the first time.
+ lastInt63 int64
+}
+
+// Int63 draws the next value from the wrapped source, stores it in lastInt63
+// and returns it unchanged.
+func (d *savingRandSource) Int63() int64 {
+ i := d.s.Int63()
+ d.lastInt63 = i
+ return i
+}
+// Seed forwards seed to the wrapped source; lastInt63 is not modified.
+func (d *savingRandSource) Seed(seed int64) {
+ d.s.Seed(seed)
+}
+
// TestRouterSolicitation tests the initial Router Solicitations that are sent
// when a NIC newly becomes enabled.
func TestRouterSolicitation(t *testing.T) {
@@ -5407,6 +5340,9 @@ func TestRouterSolicitation(t *testing.T) {
t.Fatalf("unexpectedly got a packet = %#v", p)
}
}
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
+ }
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -5416,8 +5352,10 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})},
- Clock: clock,
+ Clock: clock,
+ RandSource: &randSource,
})
+
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -5430,19 +5368,27 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ if remaining != 0 {
+ maxRtrSolicitDelay := test.maxRtrSolicitDelay
+ if maxRtrSolicitDelay < 0 {
+ maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay
+ }
+ var actualRtrSolicitDelay time.Duration
+ if maxRtrSolicitDelay != 0 {
+ actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay
+ }
+ waitForPkt(actualRtrSolicitDelay)
remaining--
}
subTest.afterFirstRS(t, s)
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ for ; remaining != 0; remaining-- {
+ if test.effectiveRtrSolicitInt != 0 {
waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
waitForPkt(time.Nanosecond)
} else {
- waitForPkt(test.effectiveRtrSolicitInt)
+ waitForPkt(0)
}
}
@@ -5538,12 +5484,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
+ waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) {
t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
+ clock.Advance(timeout)
+ p, ok := e.Read()
if !ok {
t.Fatal("timed out waiting for packet")
}
@@ -5557,6 +5502,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -5566,6 +5512,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
MaxRtrSolicitationDelay: delay,
},
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -5573,13 +5520,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(delay)
+ if _, ok := e.Read(); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok = e.ReadContext(ctx); ok {
+ clock.Advance(interval)
+ if _, ok = e.Read(); ok {
t.Fatal("should not have sent more than one RS message")
}
}
@@ -5587,9 +5532,8 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(delay)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
@@ -5600,21 +5544,19 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncPositiveEventTimeout)
- waitForPkt(interval + defaultAsyncPositiveEventTimeout)
- waitForPkt(interval + defaultAsyncPositiveEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ waitForPkt(clock, delay)
+ waitForPkt(clock, interval)
+ waitForPkt(clock, interval)
+ clock.Advance(interval)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
}
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(interval)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
}
})
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 509f5ce5c..08857e1a9 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
func (n *neighborCache) init(nic *nic, r LinkAddressResolver) {
*n = neighborCache{
nic: nic,
- state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator),
+ state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator),
linkRes: r,
}
n.mu.Lock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 9821a18d3..7de25fe37 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -15,8 +15,6 @@
package stack
import (
- "bytes"
- "encoding/binary"
"fmt"
"math"
"math/rand"
@@ -48,9 +46,6 @@ const (
// be sent to all nodes.
testEntryBroadcastAddr = tcpip.Address("broadcast")
- // testEntryLocalAddr is the source address of neighbor probes.
- testEntryLocalAddr = tcpip.Address("local_addr")
-
// testEntryBroadcastLinkAddr is a special link address sent back to
// multicast neighbor probes.
testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
@@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl
randomGenerator: rng,
},
id: 1,
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}, linkRes)
return linkRes
}
@@ -106,20 +101,24 @@ type testEntryStore struct {
entriesMap map[tcpip.Address]NeighborEntry
}
-func toAddress(i int) tcpip.Address {
- buf := new(bytes.Buffer)
- binary.Write(buf, binary.BigEndian, uint8(1))
- binary.Write(buf, binary.BigEndian, uint8(0))
- binary.Write(buf, binary.BigEndian, uint16(i))
- return tcpip.Address(buf.String())
+func toAddress(i uint16) tcpip.Address {
+ return tcpip.Address([]byte{ // 4 bytes: fixed prefix 1, 0 followed by i in big-endian order
+ 1,
+ 0,
+ byte(i >> 8), // high byte of i
+ byte(i), // low byte of i
+ })
}
-func toLinkAddress(i int) tcpip.LinkAddress {
- buf := new(bytes.Buffer)
- binary.Write(buf, binary.BigEndian, uint8(1))
- binary.Write(buf, binary.BigEndian, uint8(0))
- binary.Write(buf, binary.BigEndian, uint32(i))
- return tcpip.LinkAddress(buf.String())
+func toLinkAddress(i uint16) tcpip.LinkAddress {
+ return tcpip.LinkAddress([]byte{ // 6 bytes: fixed prefix 1, 0 followed by i as a big-endian 32-bit value
+ 1,
+ 0,
+ 0, // upper two bytes of the 32-bit value are always zero because i is a uint16
+ 0,
+ byte(i >> 8), // high byte of i
+ byte(i), // low byte of i
+ })
}
// newTestEntryStore returns a testEntryStore pre-populated with entries.
@@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore {
store := &testEntryStore{
entriesMap: make(map[tcpip.Address]NeighborEntry),
}
- for i := 0; i < entryStoreSize; i++ {
+ for i := uint16(0); i < entryStoreSize; i++ {
addr := toAddress(i)
linkAddr := toLinkAddress(i)
@@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore {
}
// size returns the number of entries in the store.
-func (s *testEntryStore) size() int {
+func (s *testEntryStore) size() uint16 {
s.mu.RLock() // a read lock suffices: the map is only inspected here
defer s.mu.RUnlock()
- return len(s.entriesMap)
+ return uint16(len(s.entriesMap)) // NOTE(review): the conversion truncates if the map ever holds more than 65535 entries
}
// entry returns the entry at index i. Returns an empty entry and false if i is
// out of bounds.
-func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) {
return s.entryByAddr(toAddress(i))
}
@@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry {
entries := make([]NeighborEntry, 0, len(s.entriesMap))
s.mu.RLock()
defer s.mu.RUnlock()
- for i := 0; i < entryStoreSize; i++ {
+ for i := uint16(0); i < entryStoreSize; i++ {
addr := toAddress(i)
if entry, ok := s.entriesMap[addr]; ok {
entries = append(entries, entry)
@@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry {
}
// set modifies the link addresses of an entry.
-func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) {
addr := toAddress(i)
s.mu.Lock()
defer s.mu.Unlock()
@@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return 0
}
-type entryEvent struct {
- nicID tcpip.NICID
- address tcpip.Address
- linkAddr tcpip.LinkAddress
- state NeighborState
-}
-
func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
@@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
})
}
@@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
})
@@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext {
}
type overflowOptions struct {
- startAtEntryIndex int
+ startAtEntryIndex uint16
wantStaticEntries []NeighborEntry
}
@@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now().Add(-durationReachableNanos),
}
wantUnorderedEntries = append(wantUnorderedEntries, wantEntry)
}
@@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
{
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
clock := faketime.NewManualClock()
linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- startedAt := clock.NowNanoseconds()
+ startedAt := clock.Now()
// The following logic is very similar to overflowCache, but
// periodically refreshes the frequently used entry.
// Fill the neighbor cache to capacity
- for i := 0; i < neighborCacheSize; i++ {
+ for i := uint16(0); i < neighborCacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
@@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Keep adding more entries
- for i := neighborCacheSize; i < linkRes.entries.size(); i++ {
+ for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil {
@@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
State: Reachable,
// Can be inferred since the frequently used entry is the first to
// be created and transitioned to Reachable.
- UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(),
+ UpdatedAt: startedAt.Add(typicalLatency),
},
}
@@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now().Add(-durationReachableNanos),
})
}
@@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
if !ok {
t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now().Add(-durationReachableNanos),
})
}
@@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
{
wantEntries := []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
}
if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" {
@@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(gotEntry, wantEntry); diff != "" {
t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff)
@@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) {
linkRes.delay = 0
// Clear for every possible size of the cache
- for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ {
// Fill the neighbor cache to capacity.
- for i := 0; i < cacheSize; i++ {
+ for i := uint16(0); i < cacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 6d95e1664..0a59eecdd 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -31,10 +31,10 @@ const (
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAtNanos int64
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time // stack clock reading (clock.Now()) stored when the entry is created as static, changes state, or is removed
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
@@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *
// calling `setStateLocked`.
func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
entry := NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(),
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: cache.nic.stack.clock.Now(),
}
n := &neighborEntry{
cache: cache,
@@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) {
if ch := e.mu.done; ch != nil {
close(ch)
e.mu.done = nil
- // Dequeue the pending packets in a new goroutine to not hold up the current
+ // Dequeue the pending packets asynchronously to not hold up the current
// goroutine as writing packets may be a costly operation.
//
// At the time of writing, when writing packets, a neighbor's link address
// is resolved (which ends up obtaining the entry's lock) while holding the
- // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
- // a lock ordering violation.
- go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err)
+ // link resolution queue's lock. Dequeuing packets asynchronously avoids a
+ // lock ordering violation.
+ //
+ // NB: this is equivalent to spawning a goroutine directly using the go
+ // keyword but allows tests that use manual clocks to deterministically
+ // wait for this work to complete.
+ //
+ // The link address is read here, before the callback is scheduled, because
+ // the callback runs later without e.mu held. A go statement evaluates its
+ // arguments in the calling goroutine; a closure does not, so the value is
+ // captured explicitly to keep that behavior and avoid an unsynchronized
+ // read of e.mu.neigh.
+ linkAddr := e.mu.neigh.LinkAddr
+ e.cache.nic.stack.clock.AfterFunc(0, func() {
+ e.cache.nic.linkResQueue.dequeue(ch, linkAddr, err)
+ })
}
}
@@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
e.dispatchRemoveEventLocked()
e.cancelTimerLocked()
// TODO(https://gvisor.dev/issues/5583): test the case where this function is
@@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.mu.neigh.State
e.mu.neigh.State = next
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
config := e.nudState.Config()
switch next {
@@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// a shared lock.
e.mu.timer = timer{
done: &done,
- timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr)
@@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Unknown, Unreachable:
prev := e.mu.neigh.State
e.mu.neigh.State = Incomplete
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
switch prev {
case Unknown:
e.dispatchAddEventLocked()
case Unreachable:
e.dispatchChangeEventLocked()
- e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment()
+ e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment()
}
config := e.nudState.Config()
@@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// a shared lock.
e.mu.timer = timer{
done: &done,
- timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
// As per RFC 4861 section 7.2.2:
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 1d39ee73d..59d86d6d4 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -36,11 +36,6 @@ const (
entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
-
- // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
- // except where another value is explicitly used. It is chosen to match the
- // MTU of loopback interfaces on Linux systems.
- entryTestNetDefaultMTU = 65536
)
var (
@@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A
// ResolveStaticAddress attempts to resolve address without sending requests.
// It either resolves the name immediately or returns the empty LinkAddress.
-func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) {
return "", false
}
// LinkAddressProtocol returns the network protocol of the addresses this
// resolver can resolve.
-func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return entryTestNetNumber
}
@@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
nudConfigs: c,
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
},
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}
netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil)
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
@@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *
EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry
EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
// UpdatedAt should remain the same during address resolution.
e.mu.Lock()
- startedAt := e.mu.neigh.UpdatedAtNanos
+ startedAt := e.mu.neigh.UpdatedAt
e.mu.Unlock()
// Wait for the rest of the reachability probe transmissions, signifying
@@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want {
+ if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want {
t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: linkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: linkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8d615500f..9cac6bbd1 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -51,7 +51,7 @@ type nic struct {
name string
context NICContext
- stats NICStats
+ stats sharedStats
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
@@ -78,26 +78,13 @@ type nic struct {
}
}
-// NICStats hold statistics for a NIC.
-type NICStats struct {
- Tx DirectionStats
- Rx DirectionStats
-
- DisabledRx DirectionStats
-
- Neighbor NeighborStats
-}
-
-func makeNICStats() NICStats {
- var s NICStats
- tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
- return s
-}
-
-// DirectionStats includes packet and byte counts.
-type DirectionStats struct {
- Packets *tcpip.StatCounter
- Bytes *tcpip.StatCounter
+// makeNICStats initializes the NIC statistics and associates them to the global
+// NIC statistics.
+func makeNICStats(global tcpip.NICStats) sharedStats {
+ var stats sharedStats
+ tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem())
+ stats.init(&stats.local, &global)
+ return stats
}
type packetEndpointList struct {
@@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
id: id,
name: name,
context: ctx,
- stats: makeNICStats(),
+ stats: makeNICStats(stack.Stats().NICs),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver),
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
@@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
return err
}
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ n.stats.tx.packets.Increment()
+ n.stats.tx.bytes.IncrementBy(uint64(numBytes))
return nil
}
@@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
- n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ n.stats.tx.packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
writtenBytes += pb.Size()
}
- n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ n.stats.tx.bytes.IncrementBy(uint64(writtenBytes))
return writtenPackets, err
}
@@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if !enabled {
n.mu.RUnlock()
- n.stats.DisabledRx.Packets.Increment()
- n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
+ n.stats.disabledRx.packets.Increment()
+ n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
}
- n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
+ n.stats.rx.packets.Increment()
+ n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size()))
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
- n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ n.stats.unknownL4ProtocolRcvdPackets.Increment()
return TransportPacketProtocolUnreachable
}
@@ -800,27 +787,26 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
// ICMP packets may be longer, but until icmp.Parse is implemented, here
// we parse it using the minimum size.
if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
// We consider a malformed transport packet handled because there is
// nothing the caller can do.
return TransportPacketHandled
}
} else if !transProto.Parse(pkt) {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
}
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
}
@@ -852,7 +838,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
// If it doesn't handle it then we should do so.
switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
case UnknownDestinationPacketUnhandled:
return TransportPacketDestinationPortUnreachable
@@ -1000,3 +986,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t
return d.CheckDuplicateAddress(addr, h), nil
}
+
+func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ forwardingEP.SetForwarding(enable)
+ return nil
+}
+
+func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return false, &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return false, &tcpip.ErrNotSupported{}
+ }
+
+ return forwardingEP.Forwarding(), nil
+}
diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go
new file mode 100644
index 000000000..1773d5e8d
--- /dev/null
+++ b/pkg/tcpip/stack/nic_stats.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type sharedStats struct {
+ local tcpip.NICStats
+ multiCounterNICStats
+}
+
+// LINT.IfChange(multiCounterNICPacketStats)
+
+type multiCounterNICPacketStats struct {
+ packets tcpip.MultiCounterStat
+ bytes tcpip.MultiCounterStat
+}
+
+func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) {
+ m.packets.Init(a.Packets, b.Packets)
+ m.bytes.Init(a.Bytes, b.Bytes)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICPacketStats)
+
+// LINT.IfChange(multiCounterNICNeighborStats)
+
+type multiCounterNICNeighborStats struct {
+ unreachableEntryLookups tcpip.MultiCounterStat
+}
+
+func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) {
+ m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICNeighborStats)
+
+// LINT.IfChange(multiCounterNICStats)
+
+type multiCounterNICStats struct {
+ unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat
+ unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat
+ malformedL4RcvdPackets tcpip.MultiCounterStat
+ tx multiCounterNICPacketStats
+ rx multiCounterNICPacketStats
+ disabledRx multiCounterNICPacketStats
+ neighbor multiCounterNICNeighborStats
+}
+
+func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) {
+ m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets)
+ m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets)
+ m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets)
+ m.tx.init(&a.Tx, &b.Tx)
+ m.rx.init(&a.Rx, &b.Rx)
+ m.disabledRx.init(&a.DisabledRx, &b.DisabledRx)
+ m.neighbor.init(&a.Neighbor, &b.Neighbor)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICStats)
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 8a3005295..5cb342f78 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,13 @@
package stack
import (
+ "reflect"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
@@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
nic := nic{
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}
- if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 {
t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
}
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 {
t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
}
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
t.Errorf("got Rx.Packets = %d, want = 0", got)
}
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
t.Errorf("got Rx.Bytes = %d, want = 0", got)
}
@@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
}))
- if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
}
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 {
t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
}
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
t.Errorf("got Rx.Packets = %d, want = 0", got)
}
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
t.Errorf("got Rx.Bytes = %d, want = 0", got)
}
}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ global := tcpip.NICStats{}.FillIn()
+ nic := nic{
+ stats: makeNICStats(global),
+ }
+ multi := nic.stats.multiCounterNICStats
+ local := nic.stats.local
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index 5a94e9ac6..ca9822bca 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -16,6 +16,7 @@ package stack
import (
"math"
+ "math/rand"
"sync"
"time"
@@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 {
return defaultMaxRandomFactor
}
-// A Rand is a source of random numbers.
-type Rand interface {
- // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
- Float32() float32
-}
-
// NUDState stores states needed for calculating reachable time.
type NUDState struct {
- rng Rand
+ clock tcpip.Clock
+ rng *rand.Rand
- // mu protects the fields below.
- //
- // It is necessary for NUDState to handle its own locking since neighbor
- // entries may access the NUD state from within the goroutine spawned by
- // time.AfterFunc(). This goroutine may run concurrently with the main
- // process for controlling the neighbor cache and would otherwise introduce
- // race conditions if NUDState was not locked properly.
- mu sync.RWMutex
+ mu struct {
+ sync.RWMutex
- config NUDConfigurations
+ config NUDConfigurations
- // reachableTime is the duration to wait for a REACHABLE entry to
- // transition into STALE after inactivity. This value is calculated with
- // the algorithm defined in RFC 4861 section 6.3.2.
- reachableTime time.Duration
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
- expiration time.Time
- prevBaseReachableTime time.Duration
- prevMinRandomFactor float32
- prevMaxRandomFactor float32
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+ }
}
// NewNUDState returns new NUDState using c as configuration and the specified
// random number generator for use in recomputing ReachableTime.
-func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState {
s := &NUDState{
- rng: rng,
+ clock: clock,
+ rng: rng,
}
- s.config = c
+ s.mu.config = c
return s
}
@@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
func (s *NUDState) Config() NUDConfigurations {
s.mu.RLock()
defer s.mu.RUnlock()
- return s.config
+ return s.mu.config
}
// SetConfig replaces the existing NUD configurations with c.
func (s *NUDState) SetConfig(c NUDConfigurations) {
s.mu.Lock()
defer s.mu.Unlock()
- s.config = c
+ s.mu.config = c
}
// ReachableTime returns the duration to wait for a REACHABLE entry to
@@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration {
s.mu.Lock()
defer s.mu.Unlock()
- if time.Now().After(s.expiration) ||
- s.config.BaseReachableTime != s.prevBaseReachableTime ||
- s.config.MinRandomFactor != s.prevMinRandomFactor ||
- s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ if s.clock.Now().After(s.mu.expiration) ||
+ s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime ||
+ s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor ||
+ s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor {
s.recomputeReachableTimeLocked()
}
- return s.reachableTime
+ return s.mu.reachableTime
}
// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
@@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration {
//
// s.mu MUST be locked for writing.
func (s *NUDState) recomputeReachableTimeLocked() {
- s.prevBaseReachableTime = s.config.BaseReachableTime
- s.prevMinRandomFactor = s.config.MinRandomFactor
- s.prevMaxRandomFactor = s.config.MaxRandomFactor
+ s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime
+ s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor
+ s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor
- randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+ randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor)
// Check for overflow, given that minRandomFactor and maxRandomFactor are
// guaranteed to be positive numbers.
- if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
- s.reachableTime = time.Duration(math.MaxInt64)
+ if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) {
+ s.mu.reachableTime = time.Duration(math.MaxInt64)
} else if randomFactor == 1 {
// Avoid loss of precision when a large base reachable time is used.
- s.reachableTime = s.config.BaseReachableTime
+ s.mu.reachableTime = s.mu.config.BaseReachableTime
} else {
- reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
- s.reachableTime = time.Duration(reachableTime)
+ reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor)
+ s.mu.reachableTime = time.Duration(reachableTime)
}
- s.expiration = time.Now().Add(2 * time.Hour)
+ s.mu.expiration = s.clock.Now().Add(2 * time.Hour)
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index e1253f310..1aeb2f8a5 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -16,6 +16,7 @@ package stack_test
import (
"math"
+ "math/rand"
"testing"
"time"
@@ -28,17 +29,15 @@ import (
)
const (
- defaultBaseReachableTime = 30 * time.Second
- minimumBaseReachableTime = time.Millisecond
- defaultMinRandomFactor = 0.5
- defaultMaxRandomFactor = 1.5
- defaultRetransmitTimer = time.Second
- minimumRetransmitTimer = time.Millisecond
- defaultDelayFirstProbeTime = 5 * time.Second
- defaultMaxMulticastProbes = 3
- defaultMaxUnicastProbes = 3
- defaultMaxAnycastDelayTime = time.Second
- defaultMaxReachbilityConfirmations = 3
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
defaultFakeRandomNum = 0.5
)
@@ -48,12 +47,14 @@ type fakeRand struct {
num float32
}
-var _ stack.Rand = (*fakeRand)(nil)
+var _ rand.Source = (*fakeRand)(nil)
-func (f *fakeRand) Float32() float32 {
- return f.num
+func (f *fakeRand) Int63() int64 {
+ return int64(f.num * float32(1<<63))
}
+func (*fakeRand) Seed(int64) {}
+
func TestNUDFunctions(t *testing.T) {
const nicID = 1
@@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) {
t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
} else if test.expectedErr == nil {
if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}},
neighbors,
); diff != "" {
t.Errorf("neighbors mismatch (-want +got):\n%s", diff)
@@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) {
rng := fakeRand{
num: defaultFakeRandomNum,
}
- s := stack.NewNUDState(c, &rng)
+ var clock faketime.NullClock
+ s := stack.NewNUDState(c, &clock, rand.New(&rng))
if got, want := s.ReachableTime(), test.want; got != want {
t.Errorf("got ReachableTime = %q, want = %q", got, want)
}
@@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) {
rng := fakeRand{
num: defaultFakeRandomNum,
}
- s := stack.NewNUDState(c, &rng)
+ var clock faketime.NullClock
+ s := stack.NewNUDState(c, &clock, rand.New(&rng))
old := s.ReachableTime()
if got, want := s.ReachableTime(), old; got != want {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index e2e073091..9192d8433 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -134,7 +134,7 @@ type PacketBuffer struct {
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
- // NICID is the ID of the interface the network packet was received at.
+ // NICID is the ID of the last interface the network packet was handled at.
NICID tcpip.NICID
// RXTransportChecksumValidated indicates that transport checksum verification
@@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int {
func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View {
h := &pk.headers[typ]
if h.length > 0 {
- panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size))
}
if pk.pushed+size > pk.reserved {
- panic("not enough headroom reserved")
+ panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved))
}
pk.pushed += size
h.offset = -pk.pushed
@@ -261,7 +261,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, c
if h.length > 0 {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
- if pk.headerOffset()+pk.consumed+size > int(pk.buf.Size()) {
+ if pk.reserved+pk.consumed+size > int(pk.buf.Size()) {
return nil, false
}
h.offset = pk.consumed
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 1c1aeb950..a8da34992 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -259,6 +259,37 @@ func TestPacketHeaderPushConsumeMixed(t *testing.T) {
})
}
+func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := concatViews(network, data)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ data: initData,
+ })
+
+ // 2. Consume network header, with a number of bytes too large.
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1)
+ if ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork)
+ }
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ data: initData,
+ })
+}
+
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
const headerSize = 10
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
index 421fb5c15..c8294eb6e 100644
--- a/pkg/tcpip/stack/rand.go
+++ b/pkg/tcpip/stack/rand.go
@@ -15,7 +15,7 @@
package stack
import (
- mathrand "math/rand"
+ "math/rand"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -23,7 +23,7 @@ import (
// lockedRandomSource provides a threadsafe rand.Source.
type lockedRandomSource struct {
mu sync.Mutex
- src mathrand.Source
+ src rand.Source
}
func (r *lockedRandomSource) Int63() (n int64) {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a82c807b4..a038389e0 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -420,7 +420,7 @@ const (
PermanentExpired
// Temporary is an endpoint, created on a one-off basis to temporarily
- // consider the NIC bound an an address that it is not explictiy bound to
+ // consider the NIC bound to an address that it is not explicitly bound to
// (such as a permanent address). Its reference count must not be biased by 1
// so that the address is removed immediately when references to it are no
// longer held.
@@ -630,7 +630,7 @@ type NetworkEndpoint interface {
// HandlePacket takes ownership of pkt.
HandlePacket(pkt *PacketBuffer)
- // Close is called when the endpoint is reomved from a stack.
+ // Close is called when the endpoint is removed from a stack.
Close()
// NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
@@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface {
IPStats() *tcpip.IPStats
}
-// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
-type ForwardingNetworkProtocol interface {
- NetworkProtocol
+// ForwardingNetworkEndpoint is a network endpoint that may forward packets.
+type ForwardingNetworkEndpoint interface {
+ NetworkEndpoint
// Forwarding returns the forwarding configuration.
Forwarding() bool
@@ -968,7 +968,7 @@ type DuplicateAddressDetector interface {
// called with the result of the original DAD request.
CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition
- // SetDADConfiguations sets the configurations for DAD.
+ // SetDADConfigurations sets the configurations for DAD.
SetDADConfigurations(c DADConfigurations)
// DuplicateAddressProtocol returns the network protocol the receiver can
@@ -979,7 +979,7 @@ type DuplicateAddressDetector interface {
// LinkAddressResolver handles link address resolution for a network protocol.
type LinkAddressResolver interface {
// LinkAddressRequest sends a request for the link address of the target
- // address. The request is broadcasted on the local network if a remote link
+ // address. The request is broadcast on the local network if a remote link
// address is not provided.
LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error
@@ -1072,4 +1072,4 @@ type GSOEndpoint interface {
// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
// This isn't a hard limit, because it is never set into packet headers.
-const SoftwareGSOMaxSize = (1 << 16)
+const SoftwareGSOMaxSize = 1 << 16
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 8a044c073..f17c04277 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool {
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
- if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
+ if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 483a960c8..81fabe29a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,17 +20,16 @@
package stack
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
- mathrand "math/rand"
+ "math/rand"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/atomicbitops"
- "gvisor.dev/gvisor/pkg/rand"
+ cryptorand "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -40,13 +39,6 @@ import (
)
const (
- // ageLimit is set to the same cache stale time used in Linux.
- ageLimit = 1 * time.Minute
- // resolutionTimeout is set to the same ARP timeout used in Linux.
- resolutionTimeout = 1 * time.Second
- // resolutionAttempts is set to the same ARP retries used in Linux.
- resolutionAttempts = 3
-
// DefaultTOS is the default type of service value for network endpoints.
DefaultTOS = 0
)
@@ -95,8 +87,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -115,7 +108,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
- // TODO(gvisor.dev/issue/170): S/R this field.
+ // TODO(gvisor.dev/issue/4595): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -144,7 +137,7 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
- randomGenerator *mathrand.Rand
+ randomGenerator *rand.Rand
// secureRNG is a cryptographically secure random number generator.
secureRNG io.Reader
@@ -195,9 +188,9 @@ type Options struct {
// TransportProtocols lists the transport protocols to enable.
TransportProtocols []TransportProtocolFactory
- // Clock is an optional clock source used for timestampping packets.
+ // Clock is an optional clock used for timekeeping.
//
- // If no Clock is specified, the clock source will be time.Now.
+ // If Clock is nil, tcpip.NewStdClock() will be used.
Clock tcpip.Clock
// Stats are optional statistic counters.
@@ -224,15 +217,21 @@ type Options struct {
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
- // returned by rand.Read().
+ // returned by the stack secure RNG.
//
// RandSource must be thread-safe.
- RandSource mathrand.Source
+ RandSource rand.Source
- // IPTables are the initial iptables rules. If nil, iptables will allow
+ // IPTables are the initial iptables rules. If nil, DefaultIPTables is used
+ // to construct them; if DefaultIPTables is also nil, iptables will allow
// all traffic.
IPTables *IPTables
+ // DefaultIPTables is an optional iptables rules constructor that is called
+ // if IPTables is nil. If both fields are nil, iptables will allow all
+ // traffic.
+ DefaultIPTables func(uint32) *IPTables
+
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
}
@@ -330,40 +329,50 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = cryptorand.Reader
+ }
+
randSrc := opts.RandSource
if randSrc == nil {
- // Source provided by mathrand.NewSource is not thread-safe so
+ var v int64
+ if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ // Source provided by rand.NewSource is not thread-safe so
// we wrap it in a simple thread-safe version.
- randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ randSrc = &lockedRandomSource{src: rand.NewSource(v)}
}
+ randomGenerator := rand.New(randSrc)
+ seed := randomGenerator.Uint32()
if opts.IPTables == nil {
- opts.IPTables = DefaultTables()
+ if opts.DefaultIPTables == nil {
+ opts.DefaultIPTables = DefaultTables
+ }
+ opts.IPTables = opts.DefaultIPTables(seed)
}
opts.NUDConfigs.resetInvalidFields()
- if opts.SecureRNG == nil {
- opts.SecureRNG = rand.Reader
- }
-
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: seed,
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: randomGenerator,
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -492,37 +501,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
-// passed protocol and sets the default setting for newly created NICs.
-func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -659,6 +692,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -773,7 +811,7 @@ type NICInfo struct {
// MTU is the maximum transmission unit.
MTU uint32
- Stats NICStats
+ Stats tcpip.NICStats
// NetworkStats holds the stats of each NetworkEndpoint bound to the NIC.
NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats
@@ -786,6 +824,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -815,17 +857,33 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
MTU: nic.LinkEndpoint.MTU(),
- Stats: nic.stats,
+ Stats: nic.stats.local,
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1029,6 +1087,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1081,7 +1153,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1120,7 +1192,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}
@@ -1754,7 +1826,7 @@ func (s *Stack) Seed() uint32 {
// Rand returns a reference to a pseudo random generator that can be used
// to generate random numbers as required.
-func (s *Stack) Rand() *mathrand.Rand {
+func (s *Stack) Rand() *rand.Rand {
return s.randomGenerator
}
@@ -1764,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader {
return s.secureRNG
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
-func generateRandInt64() int64 {
- b := make([]byte, 8)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- buf := bytes.NewReader(b)
- var v int64
- if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
- panic(err)
- }
- return v
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
@@ -1821,9 +1872,8 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index 33824afd0..dfec4258a 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -14,78 +14,6 @@
package stack
-import "time"
-
// StackFromEnv is the global stack created in restore run.
// FIXME(b/36201077)
var StackFromEnv *Stack
-
-// saveT is invoked by stateify.
-func (t *TCPCubicState) saveT() unixTime {
- return unixTime{t.T.Unix(), t.T.UnixNano()}
-}
-
-// loadT is invoked by stateify.
-func (t *TCPCubicState) loadT(unix unixTime) {
- t.T = time.Unix(unix.second, unix.nano)
-}
-
-// saveXmitTime is invoked by stateify.
-func (t *TCPRACKState) saveXmitTime() unixTime {
- return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()}
-}
-
-// loadXmitTime is invoked by stateify.
-func (t *TCPRACKState) loadXmitTime(unix unixTime) {
- t.XmitTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveLastSendTime is invoked by stateify.
-func (t *TCPSenderState) saveLastSendTime() unixTime {
- return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()}
-}
-
-// loadLastSendTime is invoked by stateify.
-func (t *TCPSenderState) loadLastSendTime(unix unixTime) {
- t.LastSendTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRTTMeasureTime is invoked by stateify.
-func (t *TCPSenderState) saveRTTMeasureTime() unixTime {
- return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()}
-}
-
-// loadRTTMeasureTime is invoked by stateify.
-func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) {
- t.RTTMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime {
- return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()}
-}
-
-// loadMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
- r.MeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRTTMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime {
- return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()}
-}
-
-// loadRTTMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) {
- r.RTTMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveSegTime is invoked by stateify.
-func (t *TCPEndpointState) saveSegTime() unixTime {
- return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()}
-}
-
-// loadSegTime is invoked by stateify.
-func (t *TCPEndpointState) loadSegTime(unix unixTime) {
- t.SegTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ff88b1bd3..21951d05a 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct {
mu struct {
sync.RWMutex
- enabled bool
+ enabled bool
+ forwarding bool
}
nic stack.NetworkInterface
@@ -165,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
-func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -196,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -227,11 +224,6 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -300,15 +292,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
@@ -467,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer
}
}
-func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
@@ -924,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) {
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
- testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC2 after being brought down should fail.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
- testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
- testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
if upFn := test.upFn; upFn != nil {
// Writes with Routes that use NIC1 after being brought up should
@@ -945,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
}
})
}
@@ -1070,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
err := s.RemoveAddress(1, localAddr)
@@ -1122,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
{
@@ -1144,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A
// No address given, verify that there is no address assigned to the NIC.
for _, a := range info.ProtocolAddresses {
if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
- t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{})
}
}
return
@@ -1224,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 2. Add Address, everything should work.
@@ -1252,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 4. Add Address back, everything should work again.
@@ -1291,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) {
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 7. Add Address back, everything should work again.
@@ -1328,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
})
}
@@ -1578,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{})
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
@@ -1619,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
}
- protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
+ protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
if err := s.AddProtocolAddress(1, protoAddr); err != nil {
t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
}
@@ -1645,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
- defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
+ defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any}
// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
- nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
+ nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24}
nic1Gateway := testutil.MustParse4("192.168.1.1")
// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
- nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
+ nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24}
nic2Gateway := testutil.MustParse4("10.10.10.1")
// Create a new stack with two NICs.
@@ -1664,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err := s.CreateNIC(2, ep); err != nil {
t.Fatalf("CreateNIC failed: %s", err)
}
- nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
+ nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
}
- nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
+ nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
}
@@ -1713,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
},
rt...,
)
@@ -2053,7 +2045,7 @@ func TestAddAddress(t *testing.T) {
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
})
}
@@ -2117,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) {
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
})
}
}
@@ -2238,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
s := stack.New(stack.Options{})
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
for _, call := range test.calls {
if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
@@ -2252,46 +2244,87 @@ func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed: ", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+
+ nics := []struct {
+ addr tcpip.Address
+ txByteCount int
+ rxByteCount int
+ }{
+ {
+ addr: "\x01",
+ txByteCount: 30,
+ rxByteCount: 10,
+ },
+ {
+ addr: "\x02",
+ txByteCount: 50,
+ rxByteCount: 20,
+ },
}
- // Route all packets for address \x01 to NIC 1.
- {
- subnet, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
+
+ var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int
+ for i, nic := range nics {
+ nicid := tcpip.NICID(i)
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
+ }
+ if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
+ t.Fatal("AddAddress failed:", err)
}
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
- // Send a packet to address 1.
- buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
- }
+ {
+ subnet, err := tcpip.NewSubnet(nic.addr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}})
+ }
- if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
+ nicStats := s.NICInfo()[nicid].Stats
+
+ // Inbound packet.
+ rxBuffer := buffer.NewView(nic.rxByteCount)
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: rxBuffer.ToVectorisedView(),
+ }))
+ if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
+ }
+ if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want {
+ t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ rxPacketsTotal++
+ rxBytesTotal += nic.rxByteCount
+
+ // Outbound packet.
+ txBuffer := buffer.NewView(nic.txByteCount)
+ actualTxLength := nic.txByteCount + fakeNetHeaderLen
+ if err := sendTo(s, nic.addr, txBuffer); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := ep.Drain()
+ if got := nicStats.Tx.Packets.Value(); got != uint64(want) {
+ t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
+ }
+ if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ txPacketsTotal += want
+ txBytesTotal += actualTxLength
}
- payload := buffer.NewView(10)
- // Write a packet out via the address for NIC 1
- if err := sendTo(s, "\x01", payload); err != nil {
- t.Fatal("sendTo failed: ", err)
+ // Now verify that each NIC's stats were correctly aggregated at the stack level.
+ if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want {
+ t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want)
}
- want := uint64(ep1.Drain())
- if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
+ if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want {
+ t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want)
}
-
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
+ if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
+ }
+ if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -2320,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{})
id := tcpip.NICID(1)
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
}
@@ -2607,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
const nicID = 1
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
dadConfigs := stack.DefaultDADConfigurations()
+ clock := faketime.NewManualClock()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
}
e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
@@ -2633,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
// Wait for DAD to resolve.
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
// We should get a resolution event after 1s (default time to
// resolve as per default NDP configurations). Waiting for that
// resolution time + an extra 1s without a resolution event
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
t.Fatal(err)
@@ -3274,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
const nicID = 1
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
+ clock := faketime.NewManualClock()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -3284,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
}
e := channel.New(dadTransmits, 1280, linkAddr1)
@@ -3328,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Wait for DAD to resolve.
+ clock.Advance(dadTransmits * retransmitTimer)
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD resolution")
}
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
@@ -3841,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
// TestAddRoute tests Stack.AddRoute
func TestAddRoute(t *testing.T) {
- const nicID = 1
-
s := stack.New(stack.Options{})
subnet1, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3879,8 +3916,6 @@ func TestAddRoute(t *testing.T) {
// TestRemoveRoutes tests Stack.RemoveRoutes
func TestRemoveRoutes(t *testing.T) {
- const nicID = 1
-
s := stack.New(stack.Options{})
addressToRemove := tcpip.Address("\x01")
@@ -4227,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
- if r != nil {
+ if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
@@ -4398,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) {
if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
} else if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}},
neighbors,
); diff != "" {
t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index ddff6e2d6..e90c1a770 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -39,7 +39,7 @@ type TCPCubicState struct {
WMax float64
// T is the time when the current congestion avoidance was entered.
- T time.Time `state:".(unixTime)"`
+ T tcpip.MonotonicTime
// TimeSinceLastCongestion denotes the time since the current
// congestion avoidance was entered.
@@ -78,7 +78,7 @@ type TCPCubicState struct {
type TCPRACKState struct {
// XmitTime is the transmission timestamp of the most recent
// acknowledged segment.
- XmitTime time.Time `state:".(unixTime)"`
+ XmitTime tcpip.MonotonicTime
// EndSequence is the ending TCP sequence number of the most recent
// acknowledged segment.
@@ -216,7 +216,7 @@ type TCPRTTState struct {
// +stateify savable
type TCPSenderState struct {
// LastSendTime is the timestamp at which we sent the last segment.
- LastSendTime time.Time `state:".(unixTime)"`
+ LastSendTime tcpip.MonotonicTime
// DupAckCount is the number of Duplicate ACKs received. It is used for
// fast retransmit.
@@ -256,7 +256,7 @@ type TCPSenderState struct {
RTTMeasureSeqNum seqnum.Value
// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
- RTTMeasureTime time.Time `state:".(unixTime)"`
+ RTTMeasureTime tcpip.MonotonicTime
// Closed indicates that the caller has closed the endpoint for
// sending.
@@ -313,7 +313,7 @@ type TCPSACKInfo struct {
type RcvBufAutoTuneParams struct {
// MeasureTime is the time at which the current measurement was
// started.
- MeasureTime time.Time `state:".(unixTime)"`
+ MeasureTime tcpip.MonotonicTime
// CopiedBytes is the number of bytes copied to user space since this
// measure began.
@@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct {
// RTTMeasureTime is the absolute time at which the current RTT
// measurement period began.
- RTTMeasureTime time.Time `state:".(unixTime)"`
+ RTTMeasureTime tcpip.MonotonicTime
// Disabled is true if an explicit receive buffer is set for the
// endpoint.
@@ -429,7 +429,7 @@ type TCPEndpointState struct {
ID TCPEndpointID
// SegTime denotes the absolute time when this segment was received.
- SegTime time.Time `state:".(unixTime)"`
+ SegTime tcpip.MonotonicTime
// RcvBufState contains information about the state of the endpoint's
// receive socket buffer.
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 80ad1a9d4..8a8454a6a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,8 +16,6 @@ package stack
import (
"fmt"
- "math/rand"
-
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ seed: d.stack.Seed(),
}
eps.endpoints[id] = epsByNIC
}
@@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
return nil
}
- return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+ return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 4848495c9..0972c94de 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -18,6 +18,7 @@ import (
"io/ioutil"
"math"
"math/rand"
+ "strconv"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
type headers struct {
- srcPort, dstPort uint16
+ srcPort uint16
+ dstPort uint16
}
func newPayload() []byte {
@@ -208,7 +210,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
reuse bool
bindToDevice tcpip.NICID
}
- for _, test := range []struct {
+ tcs := []struct {
name string
// endpoints will received the inject packets.
endpoints []endpointSockopts
@@ -217,29 +219,29 @@ func TestBindToDeviceDistribution(t *testing.T) {
wantDistributions map[tcpip.NICID][]float64
}{
{
- "BindPortReuse",
+ name: "BindPortReuse",
// 5 endpoints that all have reuse set.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
- "BindToDevice",
+ name: "BindToDevice",
// 3 endpoints with various bindings.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: false, bindToDevice: 1},
{reuse: false, bindToDevice: 2},
{reuse: false, bindToDevice: 3},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
@@ -249,9 +251,9 @@ func TestBindToDeviceDistribution(t *testing.T) {
},
},
{
- "ReuseAndBindToDevice",
+ name: "ReuseAndBindToDevice",
// 6 endpoints with various bindings.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 2},
@@ -259,7 +261,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
{reuse: true, bindToDevice: 2},
{reuse: true, bindToDevice: 0},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
1: {0.5, 0.5, 0, 0, 0, 0},
@@ -270,35 +272,42 @@ func TestBindToDeviceDistribution(t *testing.T) {
1000: {0, 0, 0, 0, 0, 1},
},
},
- } {
- for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
- "IPv4": ipv4.ProtocolNumber,
- "IPv6": ipv6.ProtocolNumber,
- } {
+ }
+ protos := map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ }
+
+ for _, test := range tcs {
+ for protoName, protoNum := range protos {
for device, wantDistribution := range test.wantDistributions {
- t.Run(test.name+protoName+string(device), func(t *testing.T) {
+ t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) {
+ // Create the NICs.
var devices []tcpip.NICID
for d := range test.wantDistributions {
devices = append(devices, d)
}
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
+ // Create endpoints and bind each to a NIC, sometimes reusing ports.
eps := make(map[tcpip.Endpoint]int)
-
pollChannel := make(chan tcpip.Endpoint)
for i, endpoint := range test.endpoints {
// Try to receive the data.
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ close(ch)
+ })
var err tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
}
+ t.Cleanup(ep.Close)
eps[ep] = i
go func(ep tcpip.Endpoint) {
@@ -307,32 +316,34 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
}(ep)
- defer ep.Close()
ep.SocketOptions().SetReusePort(endpoint.reuse)
if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
- switch netProtoNum {
+ switch protoNum {
case ipv4.ProtocolNumber:
dstAddr = testDstAddrV4
case ipv6.ProtocolNumber:
dstAddr = testDstAddrV6
default:
- t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ t.Fatalf("unexpected protocol number: %d", protoNum)
}
if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
- npackets := 100000
- nports := 10000
+ // Send packets across a range of ports, checking that packets from
+ // the same source port are always demultiplexed to the same
+ // destination endpoint.
+ npackets := 10_000
+ nports := 1_000
if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
- ports := make(map[uint16]tcpip.Endpoint)
+ endpoints := make(map[uint16]tcpip.Endpoint)
stats := make(map[tcpip.Endpoint]int)
for i := 0; i < npackets; i++ {
// Send a packet.
@@ -342,13 +353,13 @@ func TestBindToDeviceDistribution(t *testing.T) {
srcPort: testSrcPort + port,
dstPort: testDstPort,
}
- switch netProtoNum {
+ switch protoNum {
case ipv4.ProtocolNumber:
c.sendV4Packet(payload, hdrs, device)
case ipv6.ProtocolNumber:
c.sendV6Packet(payload, hdrs, device)
default:
- t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ t.Fatalf("unexpected protocol number: %d", protoNum)
}
ep := <-pollChannel
@@ -357,11 +368,11 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
stats[ep]++
if i < nports {
- ports[uint16(i)] = ep
+ endpoints[uint16(i)] = ep
} else {
// Check that all packets from one client are handled by the same
// socket.
- if want, got := ports[port], ep; want != got {
+ if want, got := endpoints[port], ep; want != got {
t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
}
}
diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go
index 7ce43a68e..371da2f40 100644
--- a/pkg/tcpip/stdclock.go
+++ b/pkg/tcpip/stdclock.go
@@ -60,11 +60,11 @@ type stdClock struct {
// monotonicOffset is assigned maxMonotonic after restore so that the
// monotonic time will continue from where it "left off" before saving as part
// of S/R.
- monotonicOffset int64 `state:"nosave"`
+ monotonicOffset MonotonicTime `state:"nosave"`
// monotonicMU protects maxMonotonic.
monotonicMU sync.Mutex `state:"nosave"`
- maxMonotonic int64
+ maxMonotonic MonotonicTime
}
// NewStdClock returns an instance of a clock that uses the time package.
@@ -76,25 +76,25 @@ func NewStdClock() Clock {
var _ Clock = (*stdClock)(nil)
-// NowNanoseconds implements Clock.NowNanoseconds.
-func (*stdClock) NowNanoseconds() int64 {
- return time.Now().UnixNano()
+// Now implements Clock.Now.
+func (*stdClock) Now() time.Time {
+ return time.Now()
}
// NowMonotonic implements Clock.NowMonotonic.
-func (s *stdClock) NowMonotonic() int64 {
+func (s *stdClock) NowMonotonic() MonotonicTime {
sinceBase := time.Since(s.baseTime)
if sinceBase < 0 {
panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime))
}
- monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset
+ monotonicValue := s.monotonicOffset.Add(sinceBase)
s.monotonicMU.Lock()
defer s.monotonicMU.Unlock()
// Monotonic time values must never decrease.
- if monotonicValue > s.maxMonotonic {
+ if s.maxMonotonic.Before(monotonicValue) {
s.maxMonotonic = monotonicValue
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 797778e08..8f2658f64 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
+// MonotonicTime is a monotonic clock reading.
+//
+// +stateify savable
+type MonotonicTime struct {
+ nanoseconds int64
+}
+
+// Before reports whether the monotonic clock reading mt is before u.
+func (mt MonotonicTime) Before(u MonotonicTime) bool {
+ return mt.nanoseconds < u.nanoseconds
+}
+
+// After reports whether the monotonic clock reading mt is after u.
+func (mt MonotonicTime) After(u MonotonicTime) bool {
+ return mt.nanoseconds > u.nanoseconds
+}
+
+// Add returns the monotonic clock reading mt+d.
+func (mt MonotonicTime) Add(d time.Duration) MonotonicTime {
+ return MonotonicTime{
+ nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(),
+ }
+}
+
+// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum)
+// value that can be stored in a Duration, the maximum (or minimum) duration
+// will be returned. To compute mt-d for a duration d, use mt.Add(-d).
+func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration {
+ return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds))
+}
+
// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
type Clock interface {
- // NowNanoseconds returns the current real time as a number of
- // nanoseconds since the Unix epoch.
- NowNanoseconds() int64
+ // Now returns the current local time.
+ Now() time.Time
- // NowMonotonic returns a monotonic time value at nanosecond resolution.
- NowMonotonic() int64
+ // NowMonotonic returns the current monotonic clock reading.
+ NowMonotonic() MonotonicTime
// AfterFunc waits for the duration to elapse and then calls f in its own
// goroutine. It returns a Timer that can be used to cancel the call using
@@ -435,11 +465,11 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns UID of the packet.
- UID() uint32
+ // KUID returns the KUID of the packet.
+ KUID() uint32
- // GID returns GID of the packet.
- GID() uint32
+ // GID returns KGID of the packet.
+ KGID() uint32
}
// ReadOptions contains options for Endpoint.Read.
@@ -861,6 +891,9 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
+// EndpointState represents the state of an endpoint.
+type EndpointState uint8
+
// CongestionControlState indicates the current congestion control state for
// TCP sender.
type CongestionControlState int
@@ -897,6 +930,9 @@ type TCPInfoOption struct {
// RTO is the retransmission timeout for the endpoint.
RTO time.Duration
+ // State is the current endpoint protocol state.
+ State EndpointState
+
// CcState is the congestion control state.
CcState CongestionControlState
@@ -1552,6 +1588,10 @@ type IPForwardingStats struct {
// were too big for the outgoing MTU.
PacketTooBig *StatCounter
+ // HostUnreachable is the number of IP packets received which could not be
+ // successfully forwarded due to an unresolvable next hop.
+ HostUnreachable *StatCounter
+
// ExtensionHeaderProblem is the number of IP packets which were dropped
// because of a problem encountered when processing an IPv6 extension
// header.
@@ -1835,37 +1875,104 @@ type UDPStats struct {
ChecksumErrors *StatCounter
}
+// NICNeighborStats holds metrics for the neighbor table.
+type NICNeighborStats struct {
+ // LINT.IfChange(NICNeighborStats)
+
+ // UnreachableEntryLookups counts the number of lookups performed on an
+ // entry in Unreachable state.
+ UnreachableEntryLookups *StatCounter
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats)
+}
+
+// NICPacketStats holds basic packet statistics.
+type NICPacketStats struct {
+ // LINT.IfChange(NICPacketStats)
+
+ // Packets is the number of packets counted.
+ Packets *StatCounter
+
+ // Bytes is the number of bytes counted.
+ Bytes *StatCounter
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats)
+}
+
+// NICStats holds NIC statistics.
+type NICStats struct {
+ // LINT.IfChange(NICStats)
+
+ // UnknownL3ProtocolRcvdPackets is the number of packets received that were
+ // for an unknown or unsupported network protocol.
+ UnknownL3ProtocolRcvdPackets *StatCounter
+
+ // UnknownL4ProtocolRcvdPackets is the number of packets received that were
+ // for an unknown or unsupported transport protocol.
+ UnknownL4ProtocolRcvdPackets *StatCounter
+
+ // MalformedL4RcvdPackets is the number of packets received by a NIC that
+ // could not be delivered to a transport endpoint because the L4 header could
+ // not be parsed.
+ MalformedL4RcvdPackets *StatCounter
+
+ // Tx contains statistics about transmitted packets.
+ Tx NICPacketStats
+
+ // Rx contains statistics about received packets.
+ Rx NICPacketStats
+
+ // DisabledRx contains statistics about received packets on disabled NICs.
+ DisabledRx NICPacketStats
+
+ // Neighbor contains statistics about neighbor entries.
+ Neighbor NICNeighborStats
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats)
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s NICStats) FillIn() NICStats {
+ InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
+}
+
// Stats holds statistics about the networking stack.
-//
-// All fields are optional.
type Stats struct {
- // UnknownProtocolRcvdPackets is the number of packets received by the
- // stack that were for an unknown or unsupported protocol.
- UnknownProtocolRcvdPackets *StatCounter
+ // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less
+ // ambiguous.
- // MalformedRcvdPackets is the number of packets received by the stack
- // that were deemed malformed.
- MalformedRcvdPackets *StatCounter
-
- // DroppedPackets is the number of packets dropped due to full queues.
+ // DroppedPackets is the number of packets dropped at the transport layer.
DroppedPackets *StatCounter
- // ICMP breaks out ICMP-specific stats (both v4 and v6).
+ // NICs is an aggregation of every NIC's statistics. These should not be
+ // incremented using this field, but using the relevant NIC multicounters.
+ NICs NICStats
+
+ // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4
+ // and v6). These should not be incremented using this field, but using the
+ // relevant NetworkEndpoint ICMP multicounters.
ICMP ICMPStats
- // IGMP breaks out IGMP-specific stats.
+ // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These
+ // should not be incremented using this field, but using the relevant
+ // NetworkEndpoint IGMP multicounters.
IGMP IGMPStats
- // IP breaks out IP-specific stats (both v4 and v6).
+ // IP is an aggregation of every NetworkEndpoint's IP statistics. These should
+ // not be incremented using this field, but using the relevant NetworkEndpoint
+ // IP multicounters.
IP IPStats
- // ARP breaks out ARP-specific stats.
+ // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These
+ // should not be incremented using this field, but using the relevant
+ // NetworkEndpoint ARP multicounters.
ARP ARPStats
- // TCP breaks out TCP-specific stats.
+ // TCP holds TCP-specific stats.
TCP TCPStats
- // UDP breaks out UDP-specific stats.
+ // UDP holds UDP-specific stats.
UDP UDPStats
}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 269081ff8..c96ae2f02 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"net"
- "strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) {
}
}
-func TestStatsString(t *testing.T) {
- got := fmt.Sprintf("%+v", Stats{}.FillIn())
-
- matchers := []string{
- // Print root-level stats correctly.
- "UnknownProtocolRcvdPackets:0",
- // Print protocol-specific stats correctly.
- "TCP:{ActiveConnectionOpenings:0",
- }
-
- for _, m := range matchers {
- if !strings.Contains(got, m) {
- t.Errorf("string.Contains(got, %q) = false", m)
- }
- }
- if t.Failed() {
- t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got)
- }
-}
-
func TestAddressWithPrefixSubnet(t *testing.T) {
tests := []struct {
addr Address
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index ab2dab60c..181ef799e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -53,12 +53,14 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 42bc53328..92fa6257d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -16,6 +16,7 @@ package forward_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -34,6 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+}
+
+func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+}
+
func TestForwarding(t *testing.T) {
const listenPort = 8080
@@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) {
const (
nicID1 = 1
nicID2 = 2
- ttl = 64
)
var (
ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
)
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
- }
-
- v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
- }
-
tests := []struct {
name string
srcAddr, dstAddr tcpip.Address
@@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
},
},
{
@@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
},
},
@@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
},
},
{
@@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
},
},
}
@@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) {
})
}
}
+
+// TestPerInterfaceForwarding enables forwarding on NIC1 only and verifies, for
+// IPv4/IPv6 unicast and multicast echo requests, that the per-NIC forwarding
+// state reported by NICInfo matches and that only packets arriving on NIC1 are
+// forwarded out of the other NIC.
+func TestPerInterfaceForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ // ARP is not used in this test but it is a network protocol that does
+ // not support forwarding. We install the protocol to make sure that
+ // forwarding information for a NIC is only reported for network
+ // protocols that support forwarding.
+ arp.NewProtocol,
+
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ for _, add := range [...]struct {
+ nicID tcpip.NICID
+ addr tcpip.ProtocolAddress
+ }{
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv4Addr,
+ },
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv6Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv4Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv6Addr,
+ },
+ } {
+ if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ }
+ }
+
+ // Only enable forwarding on NIC1 and make sure that only packets arriving
+ // on NIC1 are forwarded.
+ for _, netProto := range netProtos {
+ if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
+ // Report the protocol that failed, not the whole protocol list.
+ t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProto, err)
+ }
+ }
+
+ nicsInfo := s.NICInfo()
+ for _, subTest := range [...]struct {
+ nicID tcpip.NICID
+ nicEP *channel.Endpoint
+ otherNICID tcpip.NICID
+ otherNICEP *channel.Endpoint
+ expectForwarding bool
+ }{
+ {
+ nicID: nicID1,
+ nicEP: e1,
+ otherNICID: nicID2,
+ otherNICEP: e2,
+ expectForwarding: true,
+ },
+ {
+ nicID: nicID2,
+ nicEP: e2,
+ // The routes must point at the other NIC (NIC1, backed by e1) so
+ // that the "not forwarded" check below observes the endpoint a
+ // forwarded packet would actually leave through.
+ otherNICID: nicID1,
+ otherNICEP: e1,
+ expectForwarding: false,
+ },
+ } {
+ t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
+ nicInfo, ok := nicsInfo[subTest.nicID]
+ if !ok {
+ t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
+ } else {
+ forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
+ for _, netProto := range netProtos {
+ forwarding[netProto] = subTest.expectForwarding
+ }
+
+ if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
+ t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ })
+
+ test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
+ if p, ok := subTest.nicEP.Read(); ok {
+ t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
+ }
+ if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
+ t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
+ } else if subTest.expectForwarding {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 07ba2b837..f9ab7d0af 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
// Make sure the packet is not dropped by the next rule.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr
checker.ICMPv6Type(header.ICMPv6EchoReply)))
}
+// boolToInt maps true to 1 and false to 0 so that a boolean expectation can be
+// compared against a uint64 counter value.
+func boolToInt(v bool) uint64 {
+ if !v {
+ return 0
+ }
+ return 1
+}
+
+// setupDropFilter returns a setup function that rewrites the filter table of
+// the given stack for netProto: the first rule of hook's builtin chain gets
+// header filter f with a drop target, and the rule after it gets an accept
+// target.
+func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ isIPv6 := netProto == ipv6.ProtocolNumber
+ tables := s.IPTables()
+ table := tables.GetTable(stack.FilterID, isIPv6)
+
+ dropIdx := table.BuiltinChains[hook]
+ acceptIdx := dropIdx + 1
+ table.Rules[dropIdx].Filter = f
+ table.Rules[dropIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // The rule that follows must accept so packets that do not match f are
+ // not dropped by it.
+ table.Rules[acceptIdx].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+
+ if err := tables.ReplaceTable(stack.FilterID, table, isIPv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, isIPv6, err)
+ }
+ }
+}
+
func TestForwardingHook(t *testing.T) {
const (
nicID1 = 1
@@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) {
},
}
- setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[stack.Forward]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
- }
-
- boolToInt := func(v bool) uint64 {
- if v {
- return 1
- }
- return 0
- }
-
subTests := []struct {
name string
setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
@@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) {
{
name: "Drop",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
expectForward: false,
},
{
name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
expectForward: false,
},
{
name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
expectForward: true,
},
{
name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
expectForward: true,
},
{
name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
expectForward: true,
},
}
@@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) {
})
}
}
+
+// TestInputHookWithLocalForwarding injects an echo request on NIC1 that is
+// addressed to an address assigned to NIC2, with forwarding enabled, and
+// checks how drop rules on the Input hook apply. Per the expectations below, a
+// rule matching on the arrival NIC's name (nic1) drops the packet, while rules
+// matching on nic2 or an unrelated NIC name do not.
+func TestInputHookWithLocalForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ rx func(*channel.Endpoint)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+ },
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv6Addr),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectDrop bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on arrival NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on delivered NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop with input NIC filtering on other NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectDrop: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ })
+
+ test.rx(e1)
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
+ t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip2Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ // The echo reply, if any, leaves through NIC1 (the default route).
+ if p, ok := e1.Read(); ok == subTest.expectDrop {
+ t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
+ } else if !subTest.expectDrop {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ // Nothing is ever expected to be written to NIC2's link endpoint.
+ if p, ok := e2.Read(); ok {
+ t.Errorf("got e2.Read() = (%#v, true), want = (_, false)", p)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index c657714ba..27caa0c28 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -17,6 +17,8 @@ package link_resolution_test
import (
"bytes"
"fmt"
+ "net"
+ "runtime"
"testing"
"time"
@@ -27,12 +29,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: clock,
}
host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
+ if test.expectedWriteErr != nil {
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ } else {
+ clock.RunImmediatelyScheduledJobs()
+ }
<-ch
+
{
var r bytes.Reader
r.Reset([]byte{0})
@@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
}
}
+// TestForwardingWithLinkResolutionFailure forwards an echo request towards a
+// NIC that requires link resolution while nothing answers the resolution
+// requests. It expects MaxMulticastProbes resolution requests (ARP for IPv4,
+// NDP neighbor solicitations for IPv6) on the outgoing NIC, an ICMP
+// destination-unreachable error back on the incoming NIC, no forwarded packet,
+// and the HostUnreachable forwarding stat to be incremented once.
+func TestForwardingWithLinkResolutionFailure(t *testing.T) {
+ const (
+ incomingNICID = 1
+ outgoingNICID = 2
+ ttl = 2
+ expectedHostUnreachableErrorCount = 1
+ )
+ outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != arp.ProtocolNumber {
+ t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
+ }
+ if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(request.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
+ }
+ }
+
+ ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
+ }
+
+ snmc := header.SolicitedNodeAddr(dst)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
+ checker.SrcAddr(src),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(dst),
+ ))
+ }
+
+ icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4HostUnreachable),
+ ),
+ )
+ }
+
+ icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv6.DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
+ ),
+ )
+ }
+
+ tests := []struct {
+ name string
+ networkProtocolFactory []stack.NetworkProtocolFactory
+ networkProtocolNumber tcpip.NetworkProtocolNumber
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ incomingAddr tcpip.AddressWithPrefix
+ outgoingAddr tcpip.AddressWithPrefix
+ transportProtocol func(*stack.Stack) stack.TransportProtocol
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
+ icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address)
+ mtu uint32
+ }{
+ {
+ name: "IPv4 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ networkProtocolNumber: header.IPv4ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse4("10.0.0.2"),
+ destAddr: tcptestutil.MustParse4("11.0.0.2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ transportProtocol: icmp.NewProtocol4,
+ linkResolutionRequestChecker: arpChecker,
+ icmpReplyChecker: icmpv4Checker,
+ rx: rxICMPv4EchoRequest,
+ mtu: ipv4.MaxTotalSize,
+ },
+ {
+ name: "IPv6 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ networkProtocolNumber: header.IPv6ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse6("10::2"),
+ destAddr: tcptestutil.MustParse6("11::2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ },
+ transportProtocol: icmp.NewProtocol6,
+ linkResolutionRequestChecker: ndpChecker,
+ icmpReplyChecker: icmpv6Checker,
+ rx: rxICMPv6EchoRequest,
+ mtu: header.IPv6MinimumMTU,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: test.networkProtocolFactory,
+ TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
+ Clock: clock,
+ })
+
+ // Set up endpoint through which we will receive packets.
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
+ }
+ incomingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.incomingAddr,
+ }
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ }
+
+ // Set up endpoint through which we will attempt to forward packets.
+ outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
+ outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
+ }
+ outgoingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.outgoingAddr,
+ }
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.incomingAddr.Subnet(),
+ NIC: incomingNICID,
+ },
+ {
+ Destination: test.outgoingAddr.Subnet(),
+ NIC: outgoingNICID,
+ },
+ })
+
+ if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
+ }
+
+ test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
+
+ nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
+ }
+ // Trigger the first packet on the endpoint.
+ clock.RunImmediatelyScheduledJobs()
+
+ for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
+ request, ok := outgoingEndpoint.Read()
+ if !ok {
+ // The request is ARP for IPv4 and NDP for IPv6, so keep the
+ // message protocol-neutral.
+ t.Fatal("expected link resolution request through outgoing NIC")
+ }
+
+ test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
+
+ // Advance the clock the span of one request timeout.
+ clock.Advance(nudConfigs.RetransmitTimer)
+ }
+
+ // Next, we make a blocking read to retrieve the error packet. This is
+ // necessary because outgoing packets are dequeued asynchronously when
+ // link resolution fails, and this dequeue is what triggers the ICMP
+ // error.
+ reply, ok := incomingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP packet through incoming NIC")
+ }
+
+ test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
+
+ // Since link resolution failed, we don't expect the packet to be
+ // forwarded.
+ forwardedPacket, ok := outgoingEndpoint.Read()
+ if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
+ t.Errorf("got s.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
func TestGetLinkAddress(t *testing.T) {
const (
host1NICID = 1
@@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) {
if test.expectedErr == nil {
wantRes.LinkAddress = utils.LinkAddr2
}
- if diff := cmp.Diff(wantRes, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(wantRes, got); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("event didn't arrive")
}
})
}
@@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) {
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("event didn't arrive")
}
if test.expectedErr != nil {
@@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
d.c <- e
}
-func (d *nudDispatcher) waitForEvent(want eventInfo) error {
- if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+func (d *nudDispatcher) expectEvent(want eventInfo) error {
+ select {
+ case got := <-d.c:
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ return nil
+ default:
+ return fmt.Errorf("event didn't arrive")
}
- return nil
}
// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
@@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
neighborAddr tcpip.Address
- getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{})
+ getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
isHost1Listener bool
}{
{
@@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryAdded,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
}); err != nil {
t.Fatalf("error waiting for initial NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// See NUDConfigurations.BaseReachableTime for more information.
maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
- defer listenerEP.Close()
+ listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
defer clientEP.Close()
listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
if err := listenerEP.Bind(listenerAddr); err != nil {
@@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// with confirmation that the neighbor is reachable (indicated by a
// successful 3-way handshake).
<-clientCH
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for delay NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ <-listenerCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for reachable NUD event: %s", err)
}
+ peerEP, peerWQ, err := listenerEP.Accept(nil)
+ if err != nil {
+ t.Fatalf("listenerEP.Accept(): %s", err)
+ }
+ defer peerEP.Close()
+ peerWE, peerCH := waiter.NewChannelEntry(nil)
+ peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
+
// Wait for the neighbor to be stale again then send data to the remote.
//
// On successful transmission, the neighbor should become reachable
// without probing the neighbor as a TCP ACK would be received which is an
// indication of the neighbor being reachable.
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := clientEP.Write(&r, wOpts); err != nil {
- t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := clientEP.Write(&r, wOpts); err != nil {
+ t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ // Heads up, there is a race here.
+ //
+ // Incoming TCP segments are handled in
+ // tcp.(*endpoint).handleSegmentLocked:
+ //
+ // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
+ // segment queue and notifies waiting readers (such as this channel)
+ //
+ // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
+ // and notifies the NUD machinery that the peer is reachable
+ //
+ // Thus we must permit a delay between the readable signal and the
+ // expected NUD event.
+ //
+ // At the time of writing, this race is reliably hit with gotsan.
+ <-peerCH
+ for len(nudDisp.c) == 0 {
+ runtime.Gosched()
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// TCP should not mark the route reachable and NUD should go through the
// probe state.
clock.Advance(nudConfigs.DelayFirstProbeTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for probe NUD event: %s", err)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := peerEP.Write(&r, wOpts); err != nil {
+ t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ <-clientCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 87d36e1dd..b4b2ec723 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -44,20 +44,17 @@ type ndpDispatcher struct{}
func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
- return false
+func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {
}
-func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return true
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index f84d399fb..94b580a70 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er
return nil
}
+
+// MustParseLink parses a MAC address string into a tcpip.LinkAddress,
+// panicking on error.
+//
+// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func MustParseLink(addr string) tcpip.LinkAddress {
+ parsed, err := tcpip.ParseMACAddress(addr)
+ if err != nil {
+ panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err))
+ }
+ return parsed
+}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index 1633d0aeb..ed1ed8ac6 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -15,6 +15,7 @@
package tcpip_test
import (
+ "math"
"sync"
"testing"
"time"
@@ -22,10 +23,85 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
+func TestMonotonicTimeBefore(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if mt.Before(mt) {
+ t.Errorf("%#v.Before(%#v)", mt, mt)
+ }
+
+ one := mt.Add(1)
+ if one.Before(mt) {
+ t.Errorf("%#v.Before(%#v)", one, mt)
+ }
+ if !mt.Before(one) {
+ t.Errorf("!%#v.Before(%#v)", mt, one)
+ }
+}
+
+func TestMonotonicTimeAfter(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if mt.After(mt) {
+ t.Errorf("%#v.After(%#v)", mt, mt)
+ }
+
+ one := mt.Add(1)
+ if mt.After(one) {
+ t.Errorf("%#v.After(%#v)", mt, one)
+ }
+ if !one.After(mt) {
+ t.Errorf("!%#v.After(%#v)", one, mt)
+ }
+}
+
+func TestMonotonicTimeAddSub(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
+ t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
+ }
+
+ min := mt.Add(math.MinInt64)
+ max := mt.Add(math.MaxInt64)
+
+ if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max {
+ t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
+ }
+ if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min {
+ t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow)
+ }
+
+ if got, want := min.Sub(min), time.Duration(0); want != got {
+ t.Errorf("got min.Sub(min) = %d, want %d", got, want)
+ }
+ if got, want := max.Sub(max), time.Duration(0); want != got {
+ t.Errorf("got max.Sub(max) = %d, want %d", got, want)
+ }
+
+ if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want {
+ t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow)
+ }
+ if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want {
+ t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow)
+ }
+}
+
+func TestMonotonicTimeSub(t *testing.T) {
+ var mt tcpip.MonotonicTime
+
+ if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
+ t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
+ }
+
+ if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow {
+ t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
+ }
+ if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow {
+ t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow)
+ }
+}
+
const (
shortDuration = 1 * time.Nanosecond
middleDuration = 100 * time.Millisecond
- longDuration = 1 * time.Second
)
func TestJobReschedule(t *testing.T) {
@@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) {
wg.Wait()
}
+func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) {
+ return tcpip.NewStdClock(), time.After
+}
+
func TestJobExecution(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
job.Schedule(shortDuration)
@@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
+func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) {
+ clock, after := stdClockWithAfter()
+ return clock, after, time.Sleep
+}
+
func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after, sleep := stdClockWithAfterAndSleep()
var lock sync.Mutex
ch := make(chan struct{})
@@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration * 2)
+ sleep(middleDuration * 2)
job.Cancel()
lock.Unlock()
}
@@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration * 2):
+ case <-after(middleDuration * 2):
}
}
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after, sleep := stdClockWithAfterAndSleep()
var lock sync.Mutex
ch := make(chan struct{})
@@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration)
+ sleep(middleDuration)
job.Cancel()
job.Schedule(shortDuration)
}
@@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
// Wait for double the duration for the last timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) {
// Wait for double the duration for the last timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 7e5c79776..bbc0e3ecc 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -38,3 +38,22 @@ go_library(
"//pkg/waiter",
],
)
+
+go_test(
+ name = "icmp_x_test",
+ size = "small",
+ srcs = ["icmp_test.go"],
+ deps = [
+ ":icmp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 8afde7fca..cb316d27a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -16,6 +16,7 @@ package icmp
import (
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,14 +27,12 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// TODO(https://gvisor.dev/issues/5623): Unit test this package.
-
// +stateify savable
type icmpPacket struct {
icmpPacketEntry
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ receivedAt time.Time `state:".(int64)"`
}
type endpointState int
@@ -133,7 +132,8 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
}
// Close the receive list and drain it.
@@ -160,7 +160,7 @@ func (e *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
@@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.timestamp,
+ Timestamp: p.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -304,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
return 0, &tcpip.ErrNoRoute{}
@@ -348,8 +351,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
+func (e *endpoint) HasNIC(id int32) bool {
+ return e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
return nil
}
@@ -390,7 +400,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -606,18 +616,19 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
return id, err
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -747,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -756,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -800,7 +807,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 28a56a2d5..b8b839e4a 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,23 @@
package icmp
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *icmpPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *icmpPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves icmpPacket.data field.
func (p *icmpPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
new file mode 100644
index 000000000..cc950cbde
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -0,0 +1,235 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
+// See the issue for remaining areas of work.
+
+var (
+ localV4Addr1 = testutil.MustParse4("10.0.0.1")
+ localV4Addr2 = testutil.MustParse4("10.0.0.2")
+ remoteV4Addr = testutil.MustParse4("10.0.0.3")
+)
+
+func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
+ t.Helper()
+
+ ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
+ t.Cleanup(ep.Close)
+
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+
+ opts := stack.NICOptions{Name: name}
+ if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
+ }
+
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ }
+
+ s.AddRoute(tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: id,
+ })
+
+ return ep
+}
+
+func writePayload(buf []byte) {
+ for i := range buf {
+ buf[i] = byte(i)
+ }
+}
+
+func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
+ buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
+ writePayload(buf[header.ICMPv4MinimumSize:])
+
+ icmp := header.ICMPv4(buf)
+ icmp.SetType(header.ICMPv4Echo)
+ // No need to set the checksum; it is reset by the socket before the packet
+ // is sent.
+
+ return buf
+}
+
+// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
+// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
+//
+// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
+// the version of IP.
+func TestWriteUnboundWithBindToDevice(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ HandleLocal: true,
+ })
+
+ // Add two NICs, both with default routes on the same subnet. The first NIC
+ // added will be the default NIC for that subnet.
+ defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
+ alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)
+
+ socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
+ }
+ defer socket.Close()
+
+ echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
+
+ // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
+ // to be added is the default NIC to send packets when not explicitly bound.
+ {
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+
+ // Send a packet with SO_BINDTODEVICE. This exercises reliance on
+ // SO_BINDTODEVICE to route the packet to the alternate NIC.
+ {
+ // Use SO_BINDTODEVICE to send over the alternate NIC by default.
+ socket.SocketOptions().SetBindToDevice(2)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was not sent out the default NIC.
+ if p, ok := defaultEP.Read(); ok {
+ t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
+ }
+
+ // Verify the packet was sent out the alternate NIC.
+ p, ok := alternateEP.Read()
+ if !ok {
+ t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr2),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+ }
+
+ // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
+ // the device binding will fall back to using the default NIC to send
+ // packets.
+ {
+ socket.SocketOptions().SetBindToDevice(0)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 47f7dd1cb..fa82affc1 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -123,8 +123,6 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
- //
// Right now, the Parse() method is tied to enabled protocols passed into
// stack.New. This works for UDP and TCP, but we handle ICMP traffic even
// when netstack users don't pass ICMP as a supported protocol.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 496eca581..cd8c99d41 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -27,6 +27,7 @@ package packet
import (
"fmt"
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,9 +42,8 @@ type packet struct {
packetEntry
// data holds the actual packet data, including any headers and
// payload.
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // timestampNS is the unix time at which the packet was received.
- timestampNS int64
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
// packetInfo holds additional information like the protocol
@@ -159,7 +159,7 @@ func (ep *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (ep *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
@@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.timestampNS,
+ Timestamp: packet.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -220,19 +220,19 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Listen(backlog int) tcpip.Error {
+func (*endpoint) Listen(int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -318,7 +318,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -339,7 +339,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -451,7 +451,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
}
- packet.timestampNS = ep.stack.Clock().NowNanoseconds()
+ packet.receivedAt = ep.stack.Clock().Now()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
@@ -484,7 +484,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
}
// SetOwner implements tcpip.Endpoint.SetOwner.
-func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (*endpoint) SetOwner(tcpip.PacketOwner) {}
// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 5bd860d20..e729921db 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,11 +15,23 @@
package packet
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *packet) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *packet) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index bcec3d2e7..1bce2769a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,9 +42,8 @@ type rawPacket struct {
rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // timestampNS is the unix time at which the packet was received.
- timestampNS int64
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
}
@@ -183,7 +183,7 @@ func (e *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.mu.Lock()
@@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.timestampNS,
+ Timestamp: pkt.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -402,7 +402,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
}
// Find a route to the destination.
- route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false)
+ route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
if err != nil {
return err
}
@@ -428,7 +428,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -439,7 +439,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (*endpoint) Listen(backlog int) tcpip.Error {
+func (*endpoint) Listen(int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -513,12 +513,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -621,7 +621,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data().ExtractVV())
packet.data = combinedVV
- packet.timestampNS = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 5d6f2709c..39669b445 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,11 +15,23 @@
package raw
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *rawPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *rawPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves rawPacket.data field.
func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 0f20d3856..8436d2cf0 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -41,7 +41,6 @@ go_library(
"protocol.go",
"rack.go",
"rcv.go",
- "rcv_state.go",
"reno.go",
"reno_recovery.go",
"sack.go",
@@ -53,7 +52,6 @@ go_library(
"segment_state.go",
"segment_unsafe.go",
"snd.go",
- "snd_state.go",
"tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
@@ -134,6 +132,7 @@ go_test(
deps = [
"//pkg/sleep",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d4bd4e80e..d807b13b7 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -114,8 +114,8 @@ type listenContext struct {
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
-func timeStamp() uint32 {
- return uint32(time.Now().Unix()>>6) & tsMask
+func timeStamp(clock tcpip.Clock) uint32 {
+ return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask
}
// newListenContext creates a new listen context.
@@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc
// createCookie creates a SYN cookie for the given id and incoming sequence
// number.
func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
- ts := timeStamp()
+ ts := timeStamp(l.stack.Clock())
v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
v += (l.cookieHash(id, ts, 1) + data) & hashMask
return seqnum.Value(v)
@@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va
// sequence number. If it is, it also returns the data originally encoded in the
// cookie when createCookie was called.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
- ts := timeStamp()
+ ts := timeStamp(l.stack.Clock())
v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
cookieTS := v >> tsOffset
if ((ts - cookieTS) & tsMask) > maxTSDiff {
@@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
e.rcvQueueInfo.rcvQueueMu.Unlock()
- if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) {
// If the endpoint is shutdown, reply with reset.
//
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
@@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
switch {
+ case s.flags.Contains(header.TCPFlagRst):
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+
case s.flags == header.TCPFlagSyn:
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
@@ -591,7 +595,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
+ TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
@@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
return nil
- case (s.flags & header.TCPFlagAck) != 0:
+ case s.flags.Contains(header.TCPFlagAck):
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
mss: rcvdSynOptions.MSS,
})
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) {
+ s.incRef()
+ n.newSegmentWaker.Assert()
+ }
+
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
// the application is slow to accept or stops
@@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
default:
+ e.stack.Stats().DroppedPackets.Increment()
return nil
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5e03e7715..2137ebc25 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -19,7 +19,6 @@ import (
"math"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -92,7 +91,7 @@ type handshake struct {
rcvWndScale int
// startTime is the time at which the first SYN/SYN-ACK was sent.
- startTime time.Time
+ startTime tcpip.MonotonicTime
// deferAccept if non-zero will drop the final ACK for a passive
// handshake till an ACK segment with data is received or the timeout is
@@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int {
// resetState resets the state of the handshake object such that it becomes
// ready for a new 3-way handshake.
func (h *handshake) resetState() {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
-
h.state = handshakeSynSent
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
-func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
isnHasher.Write([]byte(id.LocalAddress))
isnHasher.Write([]byte(id.RemoteAddress))
@@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
//
// Which sort of guarantees that we won't reuse the ISN for a new
// connection for the same tuple for at least 274s.
- isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6)
return seqnum.Value(isn)
}
@@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
// response.
func (h *handshake) checkAck(s *segment) bool {
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
// RFC 793, page 36, states that a reset must be generated when
// the connection is in any non-synchronized state and an
// incoming segment acknowledges something not yet sent. The
@@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool {
func (h *handshake) synSentState(s *segment) tcpip.Error {
// RFC 793, page 37, states that in the SYN-SENT state, a reset is
// acceptable if the ack field acknowledges the SYN.
- if s.flagIsSet(header.TCPFlagRst) {
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ if s.flags.Contains(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
// was acceptable then signal the user "error: connection reset", drop
// the segment, enter CLOSED state, delete TCB, and return."
@@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// We are in the SYN-SENT state. We only care about segments that have
// the SYN flag.
- if !s.flagIsSet(header.TCPFlagSyn) {
+ if !s.flags.Contains(header.TCPFlagSyn) {
return nil
}
@@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// If this is a SYN ACK response, we only need to acknowledge the SYN
// and the handshake is completed.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
h.ep.transitionToStateEstablishedLocked(h)
@@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// synRcvdState handles a segment received when the TCP 3-way handshake is in
// the SYN-RCVD state.
func (h *handshake) synRcvdState(s *segment) tcpip.Error {
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
// RFC 793, page 37, states that in the SYN-RCVD state, a reset
// is acceptable if the sequence number is in the window.
if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
@@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
return nil
}
- if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
// We received two SYN segments with different sequence
// numbers, so we reset this and restart the whole
// process, except that we don't reset the timer.
ack := s.sequenceNumber.Add(s.logicalLen())
seq := seqnum.Value(0)
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
seq = s.ackNumber
}
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
@@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
// If deferAccept is not zero and this is a bare ACK and the
// timeout is not hit then drop the ACK.
- if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept {
h.acked = true
h.ep.stack.Stats().DroppedPackets.Increment()
return nil
@@ -412,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.ep.transitionToStateEstablishedLocked(h)
- // If the segment has data then requeue it for the receiver
- // to process it again once main loop is started.
- if s.data.Size() > 0 {
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) {
s.incRef()
- h.ep.enqueueSegment(s)
+ h.ep.newSegmentWaker.Assert()
}
return nil
}
@@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
func (h *handshake) handleSegment(s *segment) tcpip.Error {
h.sndWnd = s.window
- if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 {
h.sndWnd <<= uint8(h.sndWndScale)
}
@@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error {
// start sends the first SYN/SYN-ACK. It does not block, even if link address
// resolution is required.
func (h *handshake) start() {
- h.startTime = time.Now()
+ h.startTime = h.ep.stack.Clock().NowMonotonic()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
@@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error {
// The last is required to provide a way for the peer to complete
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
- if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept {
h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
@@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error {
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
- t *time.Timer
+ t tcpip.Timer
}
-func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
+func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
if timeout > maxTimeout {
return nil, &tcpip.ErrTimeout{}
}
bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
- bt.t = time.AfterFunc(timeout, f)
+ bt.t = clock.AfterFunc(timeout, f)
return bt, nil
}
@@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() {
}
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
- synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck))
if synOpts.TS {
s.parsedOptions.TSVal = synOpts.TSVal
s.parsedOptions.TSEcr = synOpts.TSEcr
@@ -1188,11 +1182,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// the TCPEndpointState after the segment is processed.
defer e.probeSegmentLocked()
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
return false, err
}
- } else if s.flagIsSet(header.TCPFlagSyn) {
+ } else if s.flags.Contains(header.TCPFlagSyn) {
// See: https://tools.ietf.org/html/rfc5961#section-4.1
// 1) If the SYN bit is set, irrespective of the sequence number, TCP
// MUST send an ACK (also referred to as challenge ACK) to the remote
@@ -1216,7 +1210,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// should then rely on SYN retransmission from the remote end to
// re-establish the connection.
e.snd.maybeSendOutOfWindowAck(s)
- } else if s.flagIsSet(header.TCPFlagAck) {
+ } else if s.flags.Contains(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
s.window <<= e.snd.SndWndScale
@@ -1235,7 +1229,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
- state := e.state
+ state := e.EndpointState()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
@@ -1267,7 +1261,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
// If a userTimeout is set then abort the connection if it is
// exceeded.
- if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
return &tcpip.ErrTimeout{}
@@ -1322,7 +1316,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// segments.
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
e.mu.Lock()
- var closeTimer *time.Timer
+ var closeTimer tcpip.Timer
var closeWaker sleep.Waker
epilogue := func() {
@@ -1480,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return &tcpip.ErrConnectionReset{}
}
- if n&notifyClose != 0 && closeTimer == nil {
- if e.EndpointState() == StateFinWait2 && e.closed {
+ if n&notifyClose != 0 && e.closed {
+ switch e.EndpointState() {
+ case StateEstablished:
+ // Perform full shutdown if the endpoint is still
+ // established. This can occur when notifyClose
+ // was asserted just before becoming established.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ case StateFinWait2:
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ if closeTimer == nil {
+ closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ }
}
}
@@ -1721,7 +1723,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
var timeWaitWaker sleep.Waker
s.AddWaker(&timeWaitWaker, timeWaitDone)
- timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
defer timeWaitTimer.Stop()
for {
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index 962f1d687..6985194bb 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -41,7 +41,7 @@ type cubicState struct {
func newCubicCC(s *sender) *cubicState {
return &cubicState{
TCPCubicState: stack.TCPCubicState{
- T: time.Now(),
+ T: s.ep.stack.Clock().NowMonotonic(),
Beta: 0.7,
C: 0.4,
},
@@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() {
// https://tools.ietf.org/html/rfc8312#section-4.8
if c.numCongestionEvents == 0 {
c.K = 0
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
}
@@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 {
// getCwnd returns the current congestion window as computed by CUBIC.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
- elapsed := time.Since(c.T).Seconds()
+ elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T)
+ elapsedSeconds := elapsed.Seconds()
// Compute the window as per Cubic after 'elapsed' time
// since last congestion event.
- c.WC = c.cubicCwnd(elapsed - c.K)
+ c.WC = c.cubicCwnd(elapsedSeconds - c.K)
// Compute the TCP friendly estimate of the congestion window.
- c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds())
+ c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds())
// Make sure in the TCP friendly region CUBIC performs at least
// as well as Reno.
@@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
// In Concave/Convex region of CUBIC, calculate what CUBIC window
// will be after 1 RTT and use that to grow congestion window
// for every ack.
- tEst := (time.Since(c.T) + srtt).Seconds()
+ tEst := (elapsed + srtt).Seconds()
wtRtt := c.cubicCwnd(tEst - c.K)
// As per 4.3 for each received ACK cwnd must be incremented
// by (w_cubic(t+RTT) - cwnd)/cwnd.
@@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
func (c *cubicState) HandleLossDetected() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
@@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() {
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.numCongestionEvents = 0
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
@@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() {
// PostRecovery implements congestionControl.PostRecovery.
func (c *cubicState) PostRecovery() {
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
}
// reduceSlowStartThreshold returns new SsThresh as described in
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 512053a04..dff7cb89c 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -16,10 +16,11 @@ package tcp
import (
"encoding/binary"
+ "math/rand"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) {
// in-order.
type dispatcher struct {
processors []processor
- seed uint32
- wg sync.WaitGroup
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+ wg sync.WaitGroup
}
-func (d *dispatcher) init(nProcessors int) {
+func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.close()
d.wait()
d.processors = make([]processor, nProcessors)
- d.seed = generateRandUint32()
+ d.seed = rng.Uint32()
for i := range d.processors {
p := &d.processors[i]
p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
@@ -172,12 +174,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, clock, pkt)
if !s.parse(pkt.RXTransportChecksumValidated) {
- ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
s.decRef()
@@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
}
if !s.csumValid {
- ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.ChecksumErrors.Increment()
ep.stats.ReceiveErrors.ChecksumErrors.Increment()
s.decRef()
@@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
d.selectProcessor(id).queueEndpoint(ep)
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
var payload [4]byte
binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index f148d505d..5342aacfd 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
r.Reset(data)
nep.Write(&r, tcpip.WriteOptions{})
b = c.GetPacket()
- tcp = header.TCP(header.IPv4(b).Payload())
+ tcp = header.IPv4(b).Payload()
if string(tcp.Payload()) != data {
t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 50d39cbad..a27e2110b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,12 +20,12 @@ import (
"fmt"
"io"
"math"
+ "math/rand"
"runtime"
"strings"
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,19 +38,15 @@ import (
)
// EndpointState represents the state of a TCP endpoint.
-type EndpointState uint32
+type EndpointState tcpip.EndpointState
// Endpoint states. Note that they are represented in a netstack-specific manner and
// may not be meaningful externally. Specifically, they need to be translated to
// Linux's representation for these states if presented to userspace.
const (
- // Endpoint states internal to netstack. These map to the TCP state CLOSED.
- StateInitial EndpointState = iota
- StateBound
- StateConnecting // Connect() called, but the initial SYN hasn't been sent.
- StateError
-
- // TCP protocol states.
+ _ EndpointState = iota
+ // TCP protocol states in sync with the definitions in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13
StateEstablished
StateSynSent
StateSynRecv
@@ -62,6 +58,12 @@ const (
StateLastAck
StateListen
StateClosing
+
+ // Endpoint states internal to netstack.
+ StateInitial
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
)
const (
@@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool {
}
}
+// internal returns true when the state is netstack internal.
+func (s EndpointState) internal() bool {
+ switch s {
+ case StateInitial, StateBound, StateConnecting, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// handshake returns true when s is one of the states representing an endpoint
// in the middle of a TCP handshake.
func (s EndpointState) handshake() bool {
@@ -422,12 +434,12 @@ type endpoint struct {
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
- state EndpointState `state:".(EndpointState)"`
+ state uint32 `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
// endpoint state at restore time as the socket is moved to its correct
// state.
- origEndpointState EndpointState `state:"nosave"`
+ origEndpointState uint32 `state:"nosave"`
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
@@ -468,7 +480,7 @@ type endpoint struct {
// recentTSTime is the monotonic time when we last updated
// TCPEndpointStateInner.RecentTS.
- recentTSTime time.Time `state:".(unixTime)"`
+ recentTSTime tcpip.MonotonicTime
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -626,7 +638,7 @@ type endpoint struct {
// lastOutOfWindowAckTime is the time at which an ACK was sent in response
// to an out of window segment being received by this endpoint.
- lastOutOfWindowAckTime time.Time `state:".(unixTime)"`
+ lastOutOfWindowAckTime tcpip.MonotonicTime
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -747,7 +759,7 @@ func (e *endpoint) ResumeWork() {
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ oldstate := EndpointState(atomic.LoadUint32(&e.state))
switch state {
case StateEstablished:
e.stack.Stats().TCP.CurrentEstablished.Increment()
@@ -764,18 +776,18 @@ func (e *endpoint) setEndpointState(state EndpointState) {
e.stack.Stats().TCP.CurrentEstablished.Decrement()
}
}
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState returns the current state of the endpoint.
func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ return EndpointState(atomic.LoadUint32(&e.state))
}
// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
e.RecentTS = recentTS
- e.recentTSTime = time.Now()
+ e.recentTSTime = e.stack.Clock().NowMonotonic()
}
// recentTimestamp returns the value of the recentTS field.
@@ -806,11 +818,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
},
sndQueueInfo: sndQueueInfo{
TCPSndBufState: stack.TCPSndBufState{
- SndMTU: int(math.MaxInt32),
+ SndMTU: math.MaxInt32,
},
},
waiterQueue: waiterQueue,
- state: StateInitial,
+ state: uint32(StateInitial),
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -870,9 +882,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset(e.stack.Rand())
e.acceptCond = sync.NewCond(&e.acceptMu)
- e.keepalive.timer.init(&e.keepalive.waker)
+ e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
return e
}
@@ -1189,7 +1201,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- now := time.Now()
+ now := e.stack.Clock().NowMonotonic()
if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt {
e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied
e.rcvQueueInfo.rcvQueueMu.Unlock()
@@ -1544,7 +1556,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
// Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, v)
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
e.sndQueueInfo.SndBufUsed += len(v)
e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
e.sndQueueInfo.sndQueue.PushBack(s)
@@ -1956,6 +1968,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
info := tcpip.TCPInfoOption{}
e.LockUser()
+ if state := e.EndpointState(); state.internal() {
+ info.State = tcpip.EndpointState(StateClose)
+ } else {
+ info.State = tcpip.EndpointState(state)
+ }
snd := e.snd
if snd != nil {
// We do not calculate RTT before sending the data packets. If
@@ -2198,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2224,7 +2241,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
// less than 1 second has elapsed since its recentTS was updated then
// we cannot reuse the port.
- if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second {
tcpEP.UnlockUser()
return false, nil
}
@@ -2245,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2370,7 +2387,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Queue fin segment.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil)
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil)
e.sndQueueInfo.sndQueue.PushBack(s)
e.sndQueueInfo.SndBufInQueue++
// Mark endpoint as closed.
@@ -2581,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) {
id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2731,7 +2748,7 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
@@ -2848,23 +2865,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(time.Now(), e.TSOffset)
+ return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
- return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
+ d := curTime.Sub(tcpip.MonotonicTime{})
+ return uint32(d.Milliseconds()) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
+func timeStampOffset(rng *rand.Rand) uint32 {
// Initialize a random tsOffset that will be added to the recentTS
// every time the timestamp is sent when the Timestamp option is enabled.
//
@@ -2874,7 +2888,7 @@ func timeStampOffset() uint32 {
// NOTE: This is not completely to spec as normally this should be
// initialized in a manner analogous to how sequence numbers are
// randomized per connection basis. But for now this is sufficient.
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return rng.Uint32()
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
@@ -2909,7 +2923,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
s := stack.TCPEndpointState{
TCPEndpointStateInner: e.TCPEndpointStateInner,
ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID),
- SegTime: time.Now(),
+ SegTime: e.stack.Clock().NowMonotonic(),
Receiver: e.rcv.TCPReceiverState,
Sender: e.snd.TCPSenderState,
}
@@ -2937,7 +2951,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
if cubic, ok := e.snd.cc.(*cubicState); ok {
s.Sender.Cubic = cubic.TCPCubicState
- s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T)
+ s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T)
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
@@ -3029,14 +3043,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption {
// allowOutOfWindowAck returns true if an out-of-window ACK can be sent now.
func (e *endpoint) allowOutOfWindowAck() bool {
- var limit stack.TCPInvalidRateLimitOption
- if err := e.stack.Option(&limit); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err))
- }
+ now := e.stack.Clock().NowMonotonic()
- now := time.Now()
- if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) {
- return false
+ if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) {
+ var limit stack.TCPInvalidRateLimitOption
+ if err := e.stack.Option(&limit); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err))
+ }
+ if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) {
+ return false
+ }
}
e.lastOutOfWindowAckTime = now
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 6e9777fe4..952ccacdd 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() {
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
- e.state = StateInitial
+ e.state = uint32(StateInitial)
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
- e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.keepalive.timer.init(s.Clock(), &e.keepalive.waker)
+ if snd := e.snd; snd != nil {
+ snd.resendTimer.init(s.Clock(), &snd.resendWaker)
+ snd.reorderTimer.init(s.Clock(), &snd.reorderWaker)
+ snd.probeTimer.init(s.Clock(), &snd.probeWaker)
+ }
e.stack = s
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
- epState := e.origEndpointState
+ epState := EndpointState(e.origEndpointState)
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss tcpip.TCPSendBufferSizeRangeOption
@@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
}()
case epState == StateClose:
e.isPortReserved = false
- e.state = StateClose
+ e.state = uint32(StateClose)
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
case epState == StateError:
- e.state = StateError
+ e.state = uint32(StateError)
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
-
-// saveRecentTSTime is invoked by stateify.
-func (e *endpoint) saveRecentTSTime() unixTime {
- return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
-}
-
-// loadRecentTSTime is invoked by stateify.
-func (e *endpoint) loadRecentTSTime(unix unixTime) {
- e.recentTSTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveLastOutOfWindowAckTime is invoked by stateify.
-func (e *endpoint) saveLastOutOfWindowAckTime() unixTime {
- return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()}
-}
-
-// loadLastOutOfWindowAckTime is invoked by stateify.
-func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) {
- e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2f9fe7ee0..65c86823a 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, f.stack.Clock(), pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index a3d1aa1a3..2fc282e73 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(ep, id, pkt)
+ p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, p.stack.Clock(), pkt)
defer s.decRef()
if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
- if !s.flagIsSet(header.TCPFlagRst) {
+ if !s.flags.Contains(header.TCPFlagRst) {
replyWithReset(p.stack, s, stack.DefaultTOS, 0)
}
@@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
// reset has sequence number zero and the ACK field is set to the sum
// of the sequence number and segment length of the incoming segment.
// The connection remains in the CLOSED state.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
seq = s.ackNumber
} else {
flags |= header.TCPFlagAck
@@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er
case *tcpip.TCPTimeWaitReuseOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
+ *v = p.timeWaitReuse
p.mu.RUnlock()
return nil
@@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
// TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
recovery: 0,
}
- p.dispatcher.init(runtime.GOMAXPROCS(0))
+ p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 9e332dcf7..0da4eafaa 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// update will update the RACK related fields when an ACK has been received.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
- rtt := time.Now().Sub(seg.xmitTime)
+ rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
@@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// ending sequence number of the packet which has been acknowledged
// most recently.
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) {
rc.XmitTime = seg.xmitTime
rc.EndSequence = endSeq
}
@@ -174,7 +174,7 @@ func (s *sender) schedulePTO() {
}
s.rtt.Unlock()
- now := time.Now()
+ now := s.ep.stack.Clock().NowMonotonic()
if s.resendTimer.enabled() {
if now.Add(pto).After(s.resendTimer.target) {
pto = s.resendTimer.target.Sub(now)
@@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
// been observed RACK uses reo_wnd of zero during loss recovery, in order to
// retransmit quickly, or when the number of DUPACKs exceeds the classic
// DUPACKthreshold.
-func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
+func (rc *rackControl) updateRACKReorderWindow() {
dsackSeen := rc.DSACKSeen
snd := rc.snd
@@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() {
// detectLoss marks the segment as lost if the reordering window has elapsed
// and the ACK is not received. It will also arm the reorder timer.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5.
-func (rc *rackControl) detectLoss(rcvTime time.Time) int {
+func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int {
var timeout time.Duration
numLost := 0
for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() {
@@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int {
}
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) {
timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd
if timeRemaining <= 0 {
seg.lost = true
@@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error {
return nil
}
- numLost := rc.detectLoss(time.Now())
+ numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic())
if numLost == 0 {
return nil
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 133371455..661ca604a 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -17,7 +17,6 @@ package tcp
import (
"container/heap"
"math"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -50,7 +49,7 @@ type receiver struct {
pendingRcvdSegments segmentHeap
// Time when the last ack was received.
- lastRcvdAckTime time.Time `state:".(unixTime)"`
+ lastRcvdAckTime tcpip.MonotonicTime
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
@@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
},
rcvWnd: rcvWnd,
rcvWUP: irs + 1,
- lastRcvdAckTime: time.Now(),
+ lastRcvdAckTime: ep.stack.Clock().NowMonotonic(),
}
}
@@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP RcvNxt RcvAcc new RcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
+ if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow {
// If the new window moves the right edge, then update RcvAcc.
- r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd))
+ r.RcvAcc = r.RcvNxt.Add(newWnd)
} else {
if newWnd == 0 {
// newWnd is zero but we can't advertise a zero as it would cause window
@@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
TrimSACKBlockList(&r.ep.sack, r.RcvNxt)
// Handle FIN or FIN-ACK.
- if s.flagIsSet(header.TCPFlagFin) {
+ if s.flags.Contains(header.TCPFlagFin) {
r.RcvNxt++
// Send ACK immediately.
@@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
@@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -325,9 +324,9 @@ func (r *receiver) updateRTT() {
// is first acknowledged and the receipt of data that is at least one
// window beyond the sequence number that was acknowledged.
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
- if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() {
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) {
// New measurement.
- r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic()
r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
@@ -336,14 +335,14 @@ func (r *receiver) updateRTT() {
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
+ rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
// We only store the minimum observed RTT here as this is only used in
// absence of a SRTT available from either timestamps or a sender
// measurement of RTT.
if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT {
r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt
}
- r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic()
r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
}
@@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// while the FIN is considered to occur after
// the last actual data octet in a segment in
// which it occurs.
- if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
+ if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
}
// Store the time of the last ack.
- r.lastRcvdAckTime = time.Now()
+ r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic()
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
- if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ if segLen > 0 || s.flags.Contains(header.TCPFlagFin) {
// We only store the segment if it's within our buffer size limit.
//
// Only use 75% of the receive buffer queue for out-of-order
@@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
//
// As we do not yet support PAWS, we are being conservative in ignoring
// RSTs by default.
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
return false, false
}
@@ -559,13 +558,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
- if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
+ if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
return false, true
}
// Drop the segment if it does not contain an ACK.
- if !s.flagIsSet(header.TCPFlagAck) {
+ if !s.flags.Contains(header.TCPFlagAck) {
return false, false
}
@@ -574,7 +573,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq)
}
- if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) {
// If it's a FIN-ACK then resetTimeWait and send an ACK, as it
// indicates our final ACK could have been lost.
r.ep.snd.sendAck()
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 7e5ba6ef7..ca78c96f2 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -17,7 +17,6 @@ package tcp
import (
"fmt"
"sync/atomic"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -73,9 +72,9 @@ type segment struct {
parsedOptions header.TCPOptions
options []byte `state:".([]byte)"`
hasNewSACKInfo bool
- rcvdTime time.Time `state:".(unixTime)"`
+ rcvdTime tcpip.MonotonicTime
// xmitTime is the last transmit time of this segment.
- xmitTime time.Time `state:".(unixTime)"`
+ xmitTime tcpip.MonotonicTime
xmitCount uint32
// acked indicates if the segment has already been SACKed.
@@ -88,7 +87,7 @@ type segment struct {
lost bool
}
-func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment {
netHdr := pkt.Network()
s := &segment{
refCnt: 1,
@@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
}
s.data = pkt.Data().ExtractVV().Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
- s.rcvdTime = time.Now()
+ s.rcvdTime = clock.NowMonotonic()
s.dataMemSize = s.data.Size()
return s
}
-func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
}
- s.rcvdTime = time.Now()
+ s.rcvdTime = clock.NowMonotonic()
if len(v) != 0 {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
@@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) {
oth.dataMemSize = oth.data.Size()
}
-// flagIsSet checks if at least one flag in flags is set in s.flags.
-func (s *segment) flagIsSet(flags header.TCPFlags) bool {
- return s.flags&flags != 0
-}
-
-// flagsAreSet checks if all flags in flags are set in s.flags.
-func (s *segment) flagsAreSet(flags header.TCPFlags) bool {
- return s.flags&flags == flags
-}
-
// setOwner sets the owning endpoint for this segment. It is required
// to be called to ensure memory accounting for receive/send buffer
// queues is done properly.
@@ -198,10 +187,10 @@ func (s *segment) incRef() {
// as the data length plus one for each of the SYN and FIN bits set.
func (s *segment) logicalLen() seqnum.Size {
l := seqnum.Size(s.data.Size())
- if s.flagIsSet(header.TCPFlagSyn) {
+ if s.flags.Contains(header.TCPFlagSyn) {
l++
}
- if s.flagIsSet(header.TCPFlagFin) {
+ if s.flags.Contains(header.TCPFlagFin) {
l++
}
return l
@@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
return false
}
- s.options = []byte(s.hdr[header.TCPMinimumSize:])
+ s.options = s.hdr[header.TCPMinimumSize:]
s.parsedOptions = header.ParseTCPOptions(s.options)
if skipChecksumValidation {
s.csumValid = true
@@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
// sackBlock returns a header.SACKBlock that represents this segment.
func (s *segment) sackBlock() header.SACKBlock {
- return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+ return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())}
}
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7422d8c02..dcfa80f95 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -15,8 +15,6 @@
package tcp
import (
- "time"
-
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) {
// allocated so there is no cost here.
s.options = options
}
-
-// saveRcvdTime is invoked by stateify.
-func (s *segment) saveRcvdTime() unixTime {
- return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
-}
-
-// loadRcvdTime is invoked by stateify.
-func (s *segment) loadRcvdTime(unix unixTime) {
- s.rcvdTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveXmitTime is invoked by stateify.
-func (s *segment) saveXmitTime() unixTime {
- return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
-}
-
-// loadXmitTime is invoked by stateify.
-func (s *segment) loadXmitTime(unix unixTime) {
- s.rcvdTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 486016fc0..2e6ea06f5 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -19,6 +19,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
}
func TestSegmentMerge(t *testing.T) {
+ var clock faketime.NullClock
id := stack.TransportEndpointID{}
- seg1 := newOutgoingSegment(id, buffer.NewView(10))
+ seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10))
defer seg1.decRef()
- seg2 := newOutgoingSegment(id, buffer.NewView(20))
+ seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20))
defer seg2.decRef()
checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index f43e86677..72d58dcff 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -94,7 +94,7 @@ type sender struct {
// firstRetransmittedSegXmitTime is the original transmit time of
// the first segment that was retransmitted due to RTO expiration.
- firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+ firstRetransmittedSegXmitTime tcpip.MonotonicTime
// zeroWindowProbing is set if the sender is currently probing
// for zero receive window.
@@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
SndUna: iss + 1,
SndNxt: iss + 1,
RTTMeasureSeqNum: iss + 1,
- LastSendTime: time.Now(),
+ LastSendTime: ep.stack.Clock().NowMonotonic(),
MaxPayloadSize: maxPayloadSize,
MaxSentAck: irs + 1,
FastRecovery: stack.TCPFastRecoveryState{
@@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.SndWndScale = uint8(sndWndScale)
}
- s.resendTimer.init(&s.resendWaker)
- s.reorderTimer.init(&s.reorderWaker)
- s.probeTimer.init(&s.probeWaker)
+ s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker)
+ s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker)
+ s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker)
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
@@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool {
// timeout since the first retransmission.
uto := s.ep.userTimeout
- if s.firstRetransmittedSegXmitTime.IsZero() {
+ if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) {
// We store the original xmitTime of the segment that we are
// about to retransmit as the retransmission time. This is
// required as by the time the retransmitTimer has expired the
@@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
}
- elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime)
remaining := s.maxRTO
if uto != 0 {
// Cap to the user specified timeout if one is specified.
@@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// 'S2' that meets the following 3 criteria for determining
// loss, the sequence range of one segment of up to SMSS
// octets starting with S2 MUST be returned.
- if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) {
// NextSeg():
//
// (1.a) S2 is greater than HighRxt
@@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() {
// We piggyback the probing on the retransmit timer with the
// current retransmission interval, as we may start probing while
// segments are being retransmitted.
- if s.firstRetransmittedSegXmitTime.IsZero() {
- s.firstRetransmittedSegXmitTime = time.Now()
+ if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) {
+ s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic()
}
s.resendTimer.enable(s.RTO)
}
@@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() {
func (s *sender) disableZeroWindowProbing() {
s.zeroWindowProbing = false
s.unackZeroWindowProbes = 0
- s.firstRetransmittedSegXmitTime = time.Time{}
+ s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{}
s.resendTimer.disable()
}
@@ -925,7 +925,7 @@ func (s *sender) sendData() {
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in an interval exceeding
// the retransmission timeout."
- if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO {
+ if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO {
if s.SndCwnd > InitialCwnd {
s.SndCwnd = InitialCwnd
}
@@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() {
if segEnd.LessThan(endSeq) {
endSeq = segEnd
}
- sb := header.SACKBlock{startSeq, endSeq}
+ sb := header.SACKBlock{Start: startSeq, End: endSeq}
// SetPipe():
//
// After initializing pipe to zero, the following steps are
@@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool {
// (b) The incoming acknowledgment carries no data.
seg.logicalLen() == 0 &&
// (c) The SYN and FIN bits are both off.
- !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) &&
+ !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) &&
// (d) the ACK number is equal to the greatest acknowledgment received on
// the given connection (TCP.UNA from RFC793).
seg.ackNumber == s.SndUna &&
@@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool {
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
- s.updateRTO(time.Now().Sub(s.RTTMeasureTime))
+ s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime))
s.RTTMeasureSeqNum = s.SndNxt
}
@@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
if s.SndUna == s.SndNxt {
s.Outstanding = 0
// Reset firstRetransmittedSegXmitTime to the zero value.
- s.firstRetransmittedSegXmitTime = time.Time{}
+ s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{}
s.resendTimer.disable()
s.probeTimer.disable()
}
@@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// * Upon receiving an ACK:
// * Step 4: Update RACK reordering window
- s.rc.updateRACKReorderWindow(rcvdSeg)
+ s.rc.updateRACKReorderWindow()
// After the reorder window is calculated, detect any loss by checking
// if the time elapsed after the segments are sent is greater than the
@@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
}
- seg.xmitTime = time.Now()
+ seg.xmitTime = s.ep.stack.Clock().NowMonotonic()
seg.xmitCount++
seg.lost = false
err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
@@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
- s.LastSendTime = time.Now()
+ s.LastSendTime = s.ep.stack.Clock().NowMonotonic()
if seq == s.RTTMeasureSeqNum {
s.RTTMeasureTime = s.LastSendTime
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
deleted file mode 100644
index 2f805d8ce..000000000
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "time"
-)
-
-// +stateify savable
-type unixTime struct {
- second int64
- nano int64
-}
-
-// afterLoad is invoked by stateify.
-func (s *sender) afterLoad() {
- s.resendTimer.init(&s.resendWaker)
- s.reorderTimer.init(&s.reorderWaker)
- s.probeTimer.init(&s.probeWaker)
-}
-
-// saveFirstRetransmittedSegXmitTime is invoked by stateify.
-func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
- return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
-}
-
-// loadFirstRetransmittedSegXmitTime is invoked by stateify.
-func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
- s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index c58361bc1..d6cf786a1 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -38,7 +38,7 @@ const (
func setStackRACKPermitted(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection)
+ opt := tcpip.TCPRACKLossDetection
if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
}
@@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) {
c := context.New(t, uint32(mtu))
defer c.Cleanup()
- var xmitTime time.Time
+ var xmitTime tcpip.MonotonicTime
probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
// Validate that the endpoint Sender.RACKState is what we expect.
@@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) {
}
// Write the data.
- xmitTime = time.Now()
+ xmitTime = c.Stack().Clock().NowMonotonic()
var r bytes.Reader
r.Reset(data)
if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9916182e3..9bbe9bc3e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4259,7 +4259,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(iss),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -5335,7 +5335,7 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPSeqNum(next-1),
checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -5360,12 +5360,7 @@ func TestKeepalive(t *testing.T) {
})
checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next)),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
+ checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
@@ -5507,7 +5502,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + uint16(lastPortOffset),
+ SrcPort: context.TestPort + lastPortOffset,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
@@ -5884,7 +5879,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
r.Reset(data)
newEP.Write(&r, tcpip.WriteOptions{})
pkt := c.GetPacket()
- tcp = header.TCP(header.IPv4(pkt).Payload())
+ tcp = header.IPv4(pkt).Payload()
if string(tcp.Payload()) != data {
t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
@@ -6118,7 +6113,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
}
pkt := c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ tcpHdr = header.IPv4(pkt).Payload()
if string(tcpHdr.Payload()) != data {
t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
@@ -6243,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
}
}
+func TestListenDropIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ c.Create(-1 /*epRcvBuf*/)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(1 /*backlog*/); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ initialDropped := stats.DroppedPackets.Value()
+
+ // Send RST and FIN segments, which are expected to be dropped by the listener.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ })
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin,
+ })
+
+ // To ensure that the RST, FIN sent earlier are indeed received and ignored
+ // by the listener, send a SYN and wait for the SYN to be ACKd.
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ })
+ checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs)+1),
+ ))
+
+ if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
+ t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
func TestEndpointBindListenAcceptState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -6375,7 +6418,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Allocate a large enough payload for the test.
payloadSize := receiveBufferSize * 2
- b := make([]byte, int(payloadSize))
+ b := make([]byte, payloadSize)
worker := (c.EP).(interface {
StopWork()
@@ -6429,7 +6472,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// ack, 1 for the non-zero window
p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.TCPAckNum(uint32(wantAckNum)),
+ checker.TCPAckNum(wantAckNum),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
@@ -6484,14 +6527,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
- tsVal := uint32(rawEP.TSVal)
+ tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
- return uint16(rcvWnd >> uint16(c.WindowScale))
+ return uint16(rcvWnd >> c.WindowScale)
}
// Allocate a large array to send to the endpoint.
b := make([]byte, receiveBufferSize*48)
@@ -6619,19 +6662,16 @@ func TestDelayEnabled(t *testing.T) {
defer c.Cleanup()
checkDelayOption(t, c, false, false) // Delay is disabled by default.
- for _, v := range []struct {
- delayEnabled tcpip.TCPDelayEnabled
- wantDelayOption bool
- }{
- {delayEnabled: false, wantDelayOption: false},
- {delayEnabled: true, wantDelayOption: true},
- } {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err)
- }
- checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ for _, delayEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ opt := tcpip.TCPDelayEnabled(delayEnabled)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err)
+ }
+ checkDelayOption(t, c, opt, delayEnabled)
+ })
}
}
@@ -7042,7 +7082,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Receive the SYN-ACK reply.
b = c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(b).Payload())
+ tcpHdr = header.IPv4(b).Payload()
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
ackHeaders = &context.Headers{
@@ -7467,7 +7507,7 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next)),
+ checker.TCPSeqNum(next),
checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
@@ -7545,7 +7585,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: iss,
- AckNum: seqnum.Value(c.IRS + 1),
+ AckNum: c.IRS + 1,
RcvWnd: 30000,
})
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 38a335840..5645c772e 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -15,21 +15,29 @@
package tcp
import (
+ "math"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
type timerState int
const (
+ // The timer is disabled.
timerStateDisabled timerState = iota
+ // The timer is enabled, but the clock timer may be set to an earlier
+ // expiration time due to a previous orphaned state.
timerStateEnabled
+ // The timer is disabled, but the clock timer is enabled, which means that
+ // it will cause a spurious wakeup unless the timer is enabled before the
+ // clock timer fires.
timerStateOrphaned
)
// timer is a timer implementation that reduces the interactions with the
-// runtime timer infrastructure by letting timers run (and potentially
+// clock timer infrastructure by letting timers run (and potentially
// eventually expire) even if they are stopped. It makes it cheaper to
// disable/reenable timers at the expense of spurious wakes. This is useful for
// cases when the same timer is disabled/reenabled repeatedly with relatively
@@ -39,44 +47,37 @@ const (
// (currently at least 200ms), and get disabled when acks are received, and
// reenabled when new pending segments are sent.
//
-// It is advantageous to avoid interacting with the runtime because it acquires
+// It is advantageous to avoid interacting with the clock because it acquires
// a global mutex and performs O(log n) operations, where n is the global number
// of timers, whenever a timer is enabled or disabled, and may make a syscall.
//
// This struct is thread-compatible.
type timer struct {
- // state is the current state of the timer, it can be one of the
- // following values:
- // disabled - the timer is disabled.
- // orphaned - the timer is disabled, but the runtime timer is
- // enabled, which means that it will evetually cause a
- // spurious wake (unless it gets enabled again before
- // then).
- // enabled - the timer is enabled, but the runtime timer may be set
- // to an earlier expiration time due to a previous
- // orphaned state.
state timerState
+ clock tcpip.Clock
+
// target is the expiration time of the current timer. It is only
// meaningful in the enabled state.
- target time.Time
+ target tcpip.MonotonicTime
- // runtimeTarget is the expiration time of the runtime timer. It is
+ // clockTarget is the expiration time of the clock timer. It is
// meaningful in the enabled and orphaned states.
- runtimeTarget time.Time
+ clockTarget tcpip.MonotonicTime
- // timer is the runtime timer used to wait on.
- timer *time.Timer
+ // timer is the clock timer used to wait on.
+ timer tcpip.Timer
}
// init initializes the timer. Once it expires, it the given waker will be
// asserted.
-func (t *timer) init(w *sleep.Waker) {
+func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) {
t.state = timerStateDisabled
+ t.clock = clock
- // Initialize a runtime timer that will assert the waker, then
+ // Initialize a clock timer that will assert the waker, then
// immediately stop it.
- t.timer = time.AfterFunc(time.Hour, func() {
+ t.timer = t.clock.AfterFunc(math.MaxInt64, func() {
w.Assert()
})
t.timer.Stop()
@@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool {
// The timer is enabled, but it may have expired early. Check if that's
// the case, and if so, reset the runtime timer to the correct time.
- now := time.Now()
+ now := t.clock.NowMonotonic()
if now.Before(t.target) {
- t.runtimeTarget = t.target
+ t.clockTarget = t.target
t.timer.Reset(t.target.Sub(now))
return false
}
@@ -134,11 +135,11 @@ func (t *timer) enabled() bool {
// enable enables the timer, programming the runtime timer if necessary.
func (t *timer) enable(d time.Duration) {
- t.target = time.Now().Add(d)
+ t.target = t.clock.NowMonotonic().Add(d)
// Check if we need to set the runtime timer.
- if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
- t.runtimeTarget = t.target
+ if t.state == timerStateDisabled || t.target.Before(t.clockTarget) {
+ t.clockTarget = t.target
t.timer.Reset(d)
}
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
index dbd6dff54..479752de7 100644
--- a/pkg/tcpip/transport/tcp/timer_test.go
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
func TestCleanup(t *testing.T) {
@@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) {
isAssertedTimeoutSeconds = timerDurationSeconds + 1
)
+ clock := faketime.NewManualClock()
+
tmr := timer{}
w := sleep.Waker{}
- tmr.init(&w)
+ tmr.init(clock, &w)
tmr.enable(timerDurationSeconds * time.Second)
tmr.cleanup()
@@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) {
// The waker should not be asserted.
for i := 0; i < isAssertedTimeoutSeconds; i++ {
- time.Sleep(time.Second)
+ clock.Advance(time.Second)
if w.IsAsserted() {
t.Fatalf("waker asserted unexpectedly")
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index dd5c910ae..cdc344ab7 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -49,6 +49,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f7dd50d35..def9d7186 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -17,6 +17,7 @@ package udp
import (
"io"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,25 +35,26 @@ type udpPacket struct {
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ receivedAt time.Time `state:".(int64)"`
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
// EndpointState represents the state of a UDP endpoint.
-type EndpointState uint32
+type EndpointState tcpip.EndpointState
// Endpoint states. Note that are represented in a netstack-specific manner and
// may not be meaningful externally. Specifically, they need to be translated to
// Linux's representation for these states if presented to userspace.
const (
- StateInitial EndpointState = iota
+ _ EndpointState = iota
+ StateInitial
StateBound
StateConnected
StateClosed
)
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (s EndpointState) String() string {
switch s {
case StateInitial:
@@ -98,7 +100,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
- state EndpointState
+ state uint32
route *stack.Route `state:"manual"`
dstPort uint16
ttl uint8
@@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// Linux defaults to TTL=1.
multicastTTL: 1,
multicastMemberships: make(map[multicastMembership]struct{}),
- state: StateInitial,
+ state: uint32(StateInitial),
uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
@@ -204,15 +206,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState() returns the current state of the endpoint.
func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ return EndpointState(atomic.LoadUint32(&e.state))
}
-// UniqueID implements stack.TransportEndpoint.UniqueID.
+// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
@@ -226,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error {
return err
}
-// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+// UpdateLastError implements tcpip.SocketOptionsHandler.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
}
-// Abort implements stack.TransportEndpoint.Abort.
+// Abort implements stack.TransportEndpoint.
func (e *endpoint) Abort() {
e.Close()
}
@@ -289,10 +291,10 @@ func (e *endpoint) Close() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+// ModerateRecvBuf implements tcpip.Endpoint.
+func (*endpoint) ModerateRecvBuf(int) {}
-// Read implements tcpip.Endpoint.Read.
+// Read implements tcpip.Endpoint.
func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
@@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.timestamp,
+ Timestamp: p.receivedAt.UnixNano(),
}
if e.ops.GetReceiveTOS() {
cm.HasTOS = true
@@ -581,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
-// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
e.portFlags.MostRecent = v
e.mu.Unlock()
}
-// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+// OnReusePortSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReusePortSet(v bool) {
e.mu.Lock()
e.portFlags.LoadBalanced = v
e.mu.Unlock()
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.MTUDiscoverOption:
@@ -629,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
func (e *endpoint) HasNIC(id int32) bool {
- return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+ return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.MulticastInterfaceOption:
@@ -749,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.IPv4TOSOption:
@@ -795,14 +800,14 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
- e.multicastNICID,
- e.multicastAddr,
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
}
e.mu.Unlock()
@@ -872,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
return unwrapped, netProto, nil
}
-// Disconnect implements tcpip.Endpoint.Disconnect.
+// Disconnect implements tcpip.Endpoint.
func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1075,7 +1080,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, nil /* testPort */)
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}
@@ -1301,12 +1306,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.LocalAddress,
- Port: header.UDP(hdr).DestinationPort(),
+ Port: hdr.DestinationPort(),
},
data: pkt.Data().ExtractVV(),
}
@@ -1328,7 +1333,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
packet.packetInfo.LocalAddr = localAddr
packet.packetInfo.DestinationAddr = localAddr
packet.packetInfo.NIC = pkt.NICID
- packet.timestamp = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvMu.Unlock()
@@ -1386,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
}
}
-// State implements tcpip.Endpoint.State.
+// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
return uint32(e.EndpointState())
}
@@ -1405,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
-// Wait implements tcpip.Endpoint.Wait.
+// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
-// SetOwner implements tcpip.Endpoint.SetOwner.
+// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// SocketOptions implements tcpip.Endpoint.SocketOptions.
+// SocketOptions implements tcpip.Endpoint.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 4aba68b21..1f638c3f6 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,26 +15,38 @@
package udp
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *udpPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *udpPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves udpPacket.data field.
-func (u *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save u.data directly as u.data.views may alias to u.views,
+func (p *udpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
- return u.data.Clone(nil)
+ return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
-func (u *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+func (p *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
- // of utilizing u.views for data.views.
- u.data = data
+ // of utilizing p.views for data.views.
+ p.data = data
}
// afterLoad is invoked by stateify.
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 705ad1f64..7c357cb09 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = StateConnected
+ ep.state = uint32(StateConnected)
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index dc2e3f493..4008cacf2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,17 +16,16 @@ package udp_test
import (
"bytes"
- "context"
"fmt"
"io/ioutil"
"math/rand"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -298,16 +297,18 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
- HandleLocal: true,
- })
+ return newDualTestContextWithHandleLocal(t, mtu, true)
}
-func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
t.Helper()
+ options := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: handleLocal,
+ Clock: &faketime.NullClock{},
+ }
s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
c.t.Fatalf("Packet wasn't written out")
return nil
@@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte {
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
case <-ch:
res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- case <-time.After(300 * time.Millisecond):
+ default:
if packetShouldBeDropped {
return // expected to time out
}
@@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) {
{"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
} {
t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- HandleLocal: tt.handleLocal,
- })
+ c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal)
defer c.cleanup()
c.createEndpointForFlow(unicastV4)
@@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che
payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
- var udp header.UDP
+ var udpH header.UDP
if flow.isV4() {
- udp = header.UDP(header.IPv4(b).Payload())
+ udpH = header.IPv4(b).Payload()
} else {
- udp = header.UDP(header.IPv6(b).Payload())
+ udpH = header.IPv6(b).Payload()
}
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ if !bytes.Equal(payload, udpH.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
}
- return udp.SourcePort()
+ return udpH.SourcePort()
}
func testDualWrite(c *testContext) uint16 {
@@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
r.Reset(payload)
n, err := c.ep.Write(&r, writeOpts)
if err != nil {
- c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ c.t.Fatalf("c.ep.Write(...) = %s, want nil", err)
}
if got, want := n, int64(len(payload)); got != want {
c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
@@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
name: "IPv4 unicast",
proto: header.IPv4ProtocolNumber,
flow: unicastV4,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort},
},
{
name: "IPv4 multicast",
@@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort},
},
{
name: "IPv4 broadcast",
@@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort},
},
{
name: "IPv6 unicast",
proto: header.IPv6ProtocolNumber,
flow: unicastV6,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort},
},
{
name: "IPv6 multicast",
@@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort},
},
}
@@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ Clock: &faketime.NullClock{},
})
ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
wantTTL = ep.DefaultTTL()
@@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
- want := true
+ const want = true
optionSetter(want)
got := optionGetter()
@@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) {
}
}
if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
+ if p, ok := c.linkEP.Read(); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
return
}
// ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
t.Fatalf("packet wasn't written out")
return
@@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) {
}
}
if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
+ if p, ok := c.linkEP.Read(); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
return
}
// ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
t.Fatalf("packet wasn't written out")
return
@@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) {
Data: buf.ToVectorisedView(),
}))
- if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
- t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
}
}
@@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) {
// global and endpoint stats are incremented.
func TestBadChecksumErrors(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV6} {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpoint(flow.sockProto())
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
+ c.createEndpoint(flow.sockProto())
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
- payload := newPayload()
- c.injectPacket(flow, payload, true /* badChecksum */)
+ payload := newPayload()
+ c.injectPacket(flow, payload, true /* badChecksum */)
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
})
e := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID1, e); err != nil {
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index a789c246e..7ff13cf12 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -12,6 +12,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/sentry/watchdog",
"//pkg/sync",
"//runsc/config",
"//runsc/specutils",
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 663c83679..f6a3e34c7 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -42,6 +42,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/specutils"
@@ -184,6 +185,7 @@ func TestConfig(t *testing.T) *config.Config {
conf.Network = config.NetworkNone
conf.Strace = true
conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true
+ conf.WatchdogAction = watchdog.Panic
return conf
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index 3dba36f12..d7decd78a 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"bytes_io.go",
"bytes_io_unsafe.go",
+ "marshal.go",
"usermem.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/usermem/marshal.go b/pkg/usermem/marshal.go
new file mode 100644
index 000000000..5b5a662dc
--- /dev/null
+++ b/pkg/usermem/marshal.go
@@ -0,0 +1,43 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
+)
+
+// IOCopyContext wraps an object implementing IO (this package) to implement
+// marshal.CopyContext.
+type IOCopyContext struct {
+ Ctx context.Context
+ IO IO
+ Opts IOOpts
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
+ return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
+ return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
+}