summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD5
-rw-r--r--pkg/abi/linux/signalfd.go45
-rw-r--r--pkg/amutex/BUILD3
-rw-r--r--pkg/atomicbitops/BUILD3
-rw-r--r--pkg/binary/BUILD3
-rw-r--r--pkg/bits/BUILD3
-rw-r--r--pkg/bpf/BUILD4
-rw-r--r--pkg/compressio/BUILD3
-rw-r--r--pkg/cpuid/BUILD4
-rw-r--r--pkg/cpuid/cpuid.go203
-rw-r--r--pkg/cpuid/cpuid_test.go24
-rw-r--r--pkg/eventchannel/BUILD3
-rw-r--r--pkg/fd/BUILD6
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/fdchannel/BUILD3
-rw-r--r--pkg/flipcall/BUILD4
-rw-r--r--pkg/flipcall/ctrl_futex.go32
-rw-r--r--pkg/flipcall/flipcall.go72
-rw-r--r--pkg/flipcall/flipcall_example_test.go4
-rw-r--r--pkg/flipcall/flipcall_test.go204
-rw-r--r--pkg/flipcall/flipcall_unsafe.go28
-rw-r--r--pkg/flipcall/futex_linux.go23
-rw-r--r--pkg/fspath/BUILD3
-rw-r--r--pkg/gate/BUILD3
-rw-r--r--pkg/ilist/BUILD3
-rw-r--r--pkg/linewriter/BUILD3
-rw-r--r--pkg/log/BUILD3
-rw-r--r--pkg/metric/BUILD10
-rw-r--r--pkg/p9/BUILD6
-rw-r--r--pkg/p9/client.go273
-rw-r--r--pkg/p9/client_test.go50
-rw-r--r--pkg/p9/handlers.go53
-rw-r--r--pkg/p9/messages.go103
-rw-r--r--pkg/p9/p9.go2
-rw-r--r--pkg/p9/p9test/BUILD6
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go4
-rw-r--r--pkg/p9/server.go178
-rw-r--r--pkg/p9/transport.go5
-rw-r--r--pkg/p9/transport_flipcall.go263
-rw-r--r--pkg/p9/transport_test.go4
-rw-r--r--pkg/p9/version.go9
-rw-r--r--pkg/procid/BUILD3
-rw-r--r--pkg/refs/BUILD4
-rw-r--r--pkg/refs/refcounter.go6
-rw-r--r--pkg/seccomp/BUILD4
-rw-r--r--pkg/seccomp/seccomp_unsafe.go2
-rw-r--r--pkg/secio/BUILD3
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/BUILD2
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/control/BUILD3
-rw-r--r--pkg/sentry/device/BUILD4
-rw-r--r--pkg/sentry/fs/BUILD4
-rw-r--r--pkg/sentry/fs/dirent.go4
-rw-r--r--pkg/sentry/fs/dirent_refs_test.go2
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD4
-rw-r--r--pkg/sentry/fs/file.go23
-rw-r--r--pkg/sentry/fs/file_operations.go9
-rw-r--r--pkg/sentry/fs/file_overlay.go9
-rw-r--r--pkg/sentry/fs/fsutil/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/file.go6
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go74
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go14
-rw-r--r--pkg/sentry/fs/gofer/BUILD4
-rw-r--r--pkg/sentry/fs/gofer/fs.go22
-rw-r--r--pkg/sentry/fs/gofer/inode.go15
-rw-r--r--pkg/sentry/fs/gofer/session.go27
-rw-r--r--pkg/sentry/fs/host/BUILD4
-rw-r--r--pkg/sentry/fs/host/inode.go10
-rw-r--r--pkg/sentry/fs/host/tty.go15
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/lock/BUILD4
-rw-r--r--pkg/sentry/fs/mounts.go3
-rw-r--r--pkg/sentry/fs/proc/BUILD5
-rw-r--r--pkg/sentry/fs/proc/net.go201
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD4
-rw-r--r--pkg/sentry/fs/splice.go162
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fs/tty/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD6
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD2
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go8
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go4
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go24
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD3
-rw-r--r--pkg/sentry/hostcpu/BUILD4
-rw-r--r--pkg/sentry/hostcpu/getcpu_arm64.s28
-rw-r--r--pkg/sentry/kernel/BUILD10
-rw-r--r--pkg/sentry/kernel/epoll/BUILD4
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD4
-rw-r--r--pkg/sentry/kernel/futex/BUILD4
-rw-r--r--pkg/sentry/kernel/kernel.go129
-rw-r--r--pkg/sentry/kernel/memevent/BUILD7
-rw-r--r--pkg/sentry/kernel/pipe/BUILD4
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go25
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go82
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go76
-rw-r--r--pkg/sentry/kernel/posixtimer.go8
-rw-r--r--pkg/sentry/kernel/sched/BUILD3
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD4
-rw-r--r--pkg/sentry/kernel/sessions.go8
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD22
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go137
-rw-r--r--pkg/sentry/kernel/task.go8
-rw-r--r--pkg/sentry/kernel/task_identity.go4
-rw-r--r--pkg/sentry/kernel/task_sched.go33
-rw-r--r--pkg/sentry/kernel/task_signals.go18
-rw-r--r--pkg/sentry/kernel/thread_group.go3
-rw-r--r--pkg/sentry/kernel/time/time.go27
-rw-r--r--pkg/sentry/limits/BUILD4
-rw-r--r--pkg/sentry/loader/elf.go20
-rw-r--r--pkg/sentry/loader/loader.go3
-rw-r--r--pkg/sentry/memmap/BUILD4
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/pgalloc/BUILD4
-rw-r--r--pkg/sentry/platform/interrupt/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/BUILD4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go20
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go4
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/platform/safecopy/BUILD3
-rw-r--r--pkg/sentry/safemem/BUILD3
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go2
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go227
-rw-r--r--pkg/sentry/socket/epsocket/provider.go2
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD4
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go82
-rw-r--r--pkg/sentry/strace/BUILD7
-rw-r--r--pkg/sentry/strace/linux64.go1
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go115
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go39
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go6
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/unimpl/BUILD7
-rw-r--r--pkg/sentry/usage/memory.go5
-rw-r--r--pkg/sentry/usermem/BUILD4
-rw-r--r--pkg/sentry/usermem/usermem.go8
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/file_description.go7
-rw-r--r--pkg/sleep/BUILD3
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/state/statefile/BUILD3
-rw-r--r--pkg/syserror/BUILD3
-rw-r--r--pkg/tcpip/BUILD4
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD3
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go5
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go100
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD3
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/icmpv4.go71
-rw-r--r--pkg/tcpip/header/icmpv6.go88
-rw-r--r--pkg/tcpip/header/ipv4.go33
-rw-r--r--pkg/tcpip/header/ipv6.go55
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/channel/channel.go9
-rw-r--r--pkg/tcpip/link/fdbased/BUILD7
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go50
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go3
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go179
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go23
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go (renamed from pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go)2
-rw-r--r--pkg/tcpip/link/loopback/loopback.go7
-rw-r--r--pkg/tcpip/link/muxed/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/rawfile/BUILD5
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_unsafe.go)4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go)8
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go21
-rw-r--r--pkg/tcpip/link/waitable/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go10
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/arp/arp_test.go15
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/ip_test.go25
-rw-r--r--pkg/tcpip/network/ipv4/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go44
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go31
-rw-r--r--pkg/tcpip/network/ipv6/BUILD10
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go61
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go47
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go24
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go266
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go181
-rw-r--r--pkg/tcpip/ports/BUILD3
-rw-r--r--pkg/tcpip/ports/ports.go161
-rw-r--r--pkg/tcpip/ports/ports_test.go169
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go9
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/stack/BUILD33
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go441
-rw-r--r--pkg/tcpip/stack/registration.go75
-rw-r--r--pkg/tcpip/stack/route.go15
-rw-r--r--pkg/tcpip/stack/stack.go261
-rw-r--r--pkg/tcpip/stack/stack_test.go921
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go227
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go352
-rw-r--r--pkg/tcpip/stack/transport_test.go75
-rw-r--r--pkg/tcpip/tcpip.go106
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go58
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go38
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go37
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/BUILD5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go32
-rw-r--r--pkg/tcpip/transport/tcp/connect.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go285
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go24
-rw-r--r--pkg/tcpip/transport/tcp/snd.go9
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go392
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/BUILD5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go151
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go113
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go290
-rw-r--r--pkg/tmutex/BUILD3
-rw-r--r--pkg/unet/BUILD3
-rw-r--r--pkg/urpc/BUILD3
-rw-r--r--pkg/waiter/BUILD4
252 files changed, 7854 insertions, 2582 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index ba233b93f..f45934466 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -2,9 +2,11 @@
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "linux",
@@ -44,6 +46,7 @@ go_library(
"sem.go",
"shm.go",
"signal.go",
+ "signalfd.go",
"socket.go",
"splice.go",
"tcp.go",
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
new file mode 100644
index 000000000..85fad9956
--- /dev/null
+++ b/pkg/abi/linux/signalfd.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // SFD_NONBLOCK is a signalfd(2) flag.
+ SFD_NONBLOCK = 00004000
+
+ // SFD_CLOEXEC is a signalfd(2) flag.
+ SFD_CLOEXEC = 02000000
+)
+
+// SignalfdSiginfo is the siginfo encoding for signalfds.
+type SignalfdSiginfo struct {
+ Signo uint32
+ Errno int32
+ Code int32
+ PID uint32
+ UID uint32
+ FD int32
+ TID uint32
+ Band uint32
+ Overrun uint32
+ TrapNo uint32
+ Status int32
+ Int int32
+ Ptr uint64
+ UTime uint64
+ STime uint64
+ Addr uint64
+ AddrLSB uint16
+ _ [48]uint8
+}
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 39d253b98..6bc486b62 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 47ab65346..5f59866fa 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
index 09d6c2c1f..543fb54bf 100644
--- a/pkg/binary/BUILD
+++ b/pkg/binary/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 0c2dde4f8..51967b811 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index b692aa3b1..8d31e068c 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "bpf",
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index cdec96df1..a0b21d4bd 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 830e19e07..32422f9e2 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "cpuid",
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 3fabaf445..5d61dc2ff 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
X86FeaturePREFETCHWT1: "prefetchwt1",
}
+// intelCacheDescriptors describe the caches and TLBs on the system. They are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
// Just a way to wrap cpuid function numbers.
type cpuidFunction uint32
@@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string {
return ""
}
-// FeatureSet is a set of Features for a cpu.
+// FeatureSet is a set of Features for a CPU.
//
// +stateify savable
type FeatureSet struct {
@@ -521,6 +588,15 @@ type FeatureSet struct {
// SteppingID is part of the processor signature.
SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
}
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
@@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string {
fmt.Fprintln(&b, "wp\t\t: yes")
fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+ fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
return b.String()
}
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
// AMD returns true if fs describes an AMD CPU.
func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == "AuthenticAMD"
+ return fs.VendorID == amdVendorID
}
// Intel returns true if fs describes an Intel CPU.
func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == "GenuineIntel"
+ return fs.VendorID == intelVendorID
}
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
@@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string {
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
+
if diff := fs.Subtract(hfs); diff != nil {
return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
}
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
return nil
}
@@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool {
return fs.Set[feature]
}
-// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
-// This is useful if you want to see if a FeatureSet is compatible with another
-// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
-// the host's.
-func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
- return fs.Subtract(other) == nil
-}
-
// Subtract returns the features present in fs that are not present in other.
// If all features in fs are present in other, Subtract returns nil.
func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
@@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
return
}
-// TakeFeatureIntersection will set the features in `fs` to the intersection of
-// the features in `fs` and `other` (effectively clearing any feature bits on
-// `fs` that are not also set in `other`).
-func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
- for f := range fs.Set {
- if !other.Set[f] {
- delete(fs.Set, f)
- }
- }
-}
-
// EmulateID emulates a cpuid instruction based on the feature set.
func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
switch cpuidFunction(origAx) {
@@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
bx, dx, cx = fs.vendorIDRegs()
case featureInfo:
- // clflush line size (ebx bits[15:8]) hardcoded as 8. This
- // means cache lines of size 64 bytes.
- bx = 8 << 8
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
cx = fs.blockMask(block(0))
dx = fs.blockMask(block(1))
ax = fs.signature()
@@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
// will always return 01H. Software should ignore this value
// and not interpret it as an informational descriptor." - SDM
//
- // We do not support exposing cache information, but we do set
- // this fixed field because some language runtimes (dlang) get
- // confused by ax = 0 and will loop infinitely.
- ax = 1
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
case xSaveInfo:
if !fs.UseXsave() {
return 0, 0, 0, 0
@@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet {
vendorID := vendorIDFromRegs(bx, cx, dx)
// eax=1 gets basic features in ecx:edx.
- ax, _, cx, dx := HostID(1, 0)
+ ax, bx, cx, dx := HostID(1, 0)
featureBlock0 := cx
featureBlock1 := dx
ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
// eax=7, ecx=0 gets extended features in ecx:ebx.
_, bx, cx, _ = HostID(7, 0)
@@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet {
Family: f,
Model: m,
SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
}
}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go
index 6ae14d2da..a707ebb55 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_test.go
@@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{
X86FeaturePAE: true,
}}
-func TestIsSubset(t *testing.T) {
- if !justFPU.IsSubset(justFPUandPAE) {
- t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE)
+func TestSubtract(t *testing.T) {
+ if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+ t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
}
- if justFPUandPAE.IsSubset(justFPU) {
- t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE)
- }
-}
-
-func TestTakeFeatureIntersection(t *testing.T) {
- testFeatures := HostFeatureSet()
- testFeatures.TakeFeatureIntersection(justFPU)
- if !testFeatures.IsSubset(justFPU) {
- t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set)
- }
- if !testFeatures.HasFeature(X86FeatureFPU) {
- t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU)
+ if justFPUandPAE.Subtract(justFPU) == nil {
+ t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE)
}
}
@@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) {
// if HostFeatureSet gives back junk bits.
func TestHostFeatureSet(t *testing.T) {
hostFeatures := HostFeatureSet()
- if !justFPUandPAE.IsSubset(hostFeatures) {
+ if justFPUandPAE.Subtract(hostFeatures) != nil {
t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
}
}
@@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) {
testFeatures := newEmptyFeatureSet()
testFeatures.Add(X86FeatureCLFSH)
testFeatures.Add(X86FeatureAVX)
+ testFeatures.CacheLine = 64
ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 9961baaa9..71f2abc83 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index 785c685a0..c7f549428 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -7,6 +8,9 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/unet",
+ ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..7691b477b 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,6 +22,8 @@ import (
"runtime"
"sync/atomic"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
+// DialUnix connects to a Unix Domain Socket and return the file descriptor.
+func DialUnix(path string) (*FD, error) {
+ socket, err := unet.Connect(path, false)
+ return New(socket.FD()), err
+}
+
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index e54e7371c..56495cbd9 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index bd1d614b6..5643d5f26 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -18,6 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
+ "//third_party/gvsync",
],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 865b6f640..8390915a2 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error {
*ep.dataLen() = w.Len()
// Return control to the client.
+ raceBecomeInactive()
if err := ep.futexSwitchToPeer(); err != nil {
return err
}
@@ -121,7 +122,16 @@ func (ep *Endpoint) enterFutexWait() error {
}
func (ep *Endpoint) exitFutexWait() {
- atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+ switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
+ case 0:
+ return
+ case epsShutdown:
+ // ep.ctrlShutdown() was called while we were blocked, so we are
+ // repsonsible for indicating connection shutdown.
+ ep.shutdownConn()
+ default:
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
+ }
}
func (ep *Endpoint) ctrlShutdown() {
@@ -142,5 +152,25 @@ func (ep *Endpoint) ctrlShutdown() {
break
}
}
+ } else {
+ // There is no blocked thread, so we are responsible for indicating
+ // connection shutdown.
+ ep.shutdownConn()
+ }
+}
+
+func (ep *Endpoint) shutdownConn() {
+ switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
+ case ep.activeState:
+ if err := ep.futexWakeConnState(1); err != nil {
+ log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
+ }
+ case ep.inactiveState:
+ // The peer is currently active and will detect shutdown when it tries
+ // to update the connection state.
+ case csShutdown:
+ // The peer also called Endpoint.Shutdown().
+ default:
+ log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
}
}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 5c9212c33..386cee42c 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -42,11 +42,6 @@ type Endpoint struct {
// dataCap is immutable.
dataCap uint32
- // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
- // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
- // accessed using atomic memory operations.
- shutdown uint32
-
// activeState is csClientActive if this is a client Endpoint and
// csServerActive if this is a server Endpoint.
activeState uint32
@@ -55,9 +50,27 @@ type Endpoint struct {
// csClientActive if this is a server Endpoint.
inactiveState uint32
+ // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+ // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+ // accessed using atomic memory operations.
+ shutdown uint32
+
ctrl endpointControlImpl
}
+// EndpointSide indicates which side of a connection an Endpoint belongs to.
+type EndpointSide int
+
+const (
+ // ClientSide indicates that an Endpoint is a client (initially-active;
+ // first method call should be Connect).
+ ClientSide EndpointSide = iota
+
+ // ServerSide indicates that an Endpoint is a server (initially-inactive;
+ // first method call should be RecvFirst.)
+ ServerSide
+)
+
// Init must be called on zero-value Endpoints before first use. If it
// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
//
@@ -65,7 +78,17 @@ type Endpoint struct {
// Endpoint. FD may differ between Endpoints if they are in different
// processes, but must represent the same file. The packet window must
// initially be filled with zero bytes.
-func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+ switch side {
+ case ClientSide:
+ ep.activeState = csClientActive
+ ep.inactiveState = csServerActive
+ case ServerSide:
+ ep.activeState = csServerActive
+ ep.inactiveState = csClientActive
+ default:
+ return fmt.Errorf("invalid EndpointSide: %v", side)
+ }
if pwd.Length < pageSize {
return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
}
@@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
}
ep.packet = m
ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
- // These will be overwritten by ep.Connect() for client Endpoints.
- ep.activeState = csServerActive
- ep.inactiveState = csClientActive
if err := ep.ctrlInit(opts...); err != nil {
ep.unmapPacket()
return err
@@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
// NewEndpoint is a convenience function that returns an initialized Endpoint
// allocated on the heap.
-func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
var ep Endpoint
- if err := ep.Init(pwd, opts...); err != nil {
+ if err := ep.Init(side, pwd, opts...); err != nil {
return nil, err
}
return &ep, nil
@@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() {
}
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
-// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
-// wait for concurrent calls to return. The effect of Shutdown on the peer
-// Endpoint is unspecified. Successive calls to Shutdown have no effect.
+// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
+// Endpoint, to unblock and return errors. It does not wait for concurrent
+// calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -152,28 +172,31 @@ const (
// The client is, by definition, initially active, so this must be 0.
csClientActive = 0
csServerActive = 1
+ csShutdown = 2
)
-// Connect designates ep as a client Endpoint and blocks until the peer
-// Endpoint has called Endpoint.RecvFirst().
+// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst().
//
-// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
-// ep.SendLast() have never been called.
+// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
+// ep.SendRecv(), and ep.SendLast() have never been called.
func (ep *Endpoint) Connect() error {
- ep.activeState = csClientActive
- ep.inactiveState = csServerActive
- return ep.ctrlConnect()
+ err := ep.ctrlConnect()
+ if err == nil {
+ raceBecomeActive()
+ }
+ return err
}
// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
// returns the datagram length specified by that call.
//
-// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
-// been called.
+// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and
+// ep.SendLast() have never been called.
func (ep *Endpoint) RecvFirst() (uint32, error) {
if err := ep.ctrlWaitFirst(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -200,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
// after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
// they can only shoot themselves in the foot.
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlRoundTrip(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -222,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error {
panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
}
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlWakeLast(); err != nil {
return err
}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index edb6a8bef..8d88b845d 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -38,12 +38,12 @@ func Example() {
panic(err)
}
var clientEP Endpoint
- if err := clientEP.Init(pwd); err != nil {
+ if err := clientEP.Init(ClientSide, pwd); err != nil {
panic(err)
}
defer clientEP.Destroy()
var serverEP Endpoint
- if err := serverEP.Init(pwd); err != nil {
+ if err := serverEP.Init(ServerSide, pwd); err != nil {
panic(err)
}
defer serverEP.Destroy()
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index da9d736ab..168a487ec 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi
c.pwa.Destroy()
tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
}
- if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
+ if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
c.pwa.Destroy()
tb.Fatalf("failed to create client Endpoint: %v", err)
}
- if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
+ if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
c.pwa.Destroy()
c.clientEP.Destroy()
tb.Fatalf("failed to create server Endpoint: %v", err)
@@ -62,17 +62,30 @@ func (c *testConnection) destroy() {
}
func testSendRecv(t *testing.T, c *testConnection) {
+ // This shared variable is used to confirm that synchronization between
+ // flipcall endpoints is visible to the Go race detector.
+ state := 0
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
t.Logf("server Endpoint waiting for packet 1")
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ state++
+ if state != 2 {
+ t.Errorf("shared state counter: got %d, wanted 2", state)
}
t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
if _, err := c.serverEP.SendRecv(0); err != nil {
- t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ t.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
+ }
+ state++
+ if state != 4 {
+ t.Errorf("shared state counter: got %d, wanted 4", state)
}
t.Logf("server Endpoint got packet 3")
}()
@@ -87,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
+ state++
+ if state != 1 {
+ t.Errorf("shared state counter: got %d, wanted 1", state)
+ }
t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
if _, err := c.clientEP.SendRecv(0); err != nil {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ state++
+ if state != 3 {
+ t.Errorf("shared state counter: got %d, wanted 3", state)
+ }
t.Logf("client Endpoint got packet 2, sending packet 3")
if err := c.clientEP.SendLast(0); err != nil {
t.Fatalf("client Endpoint.SendLast() failed: %v", err)
@@ -105,7 +126,30 @@ func TestSendRecv(t *testing.T) {
testSendRecv(t, c)
}
-func testShutdownConnect(t *testing.T, c *testConnection) {
+func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, false)
+}
+
+func TestShutdownBeforeConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, true)
+}
+
+func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var clientRun sync.WaitGroup
clientRun.Add(1)
go func() {
@@ -115,44 +159,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) {
}
}()
time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
- c.clientEP.Shutdown()
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
clientRun.Wait()
}
-func TestShutdownConnect(t *testing.T) {
+func TestShutdownDuringConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, false)
+}
+
+func TestShutdownDuringConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, true)
+}
+
+func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeRecvFirst(t, c, false)
+}
+
+func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownConnect(t, c)
+ testShutdownBeforeRecvFirst(t, c, true)
}
-func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
- _, err := c.serverEP.RecvFirst()
- if err == nil {
+ if _, err := c.serverEP.RecvFirst(); err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstBeforeConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstBeforeConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstBeforeConnect(t, c)
+ testShutdownDuringRecvFirstBeforeConnect(t, c, true)
}
-func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err == nil {
- t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
defer func() {
@@ -164,23 +250,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstAfterConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, true)
+}
+
+func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ }
+ // At this point, the client must be blocked in c.clientEP.SendRecv().
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if _, err := c.clientEP.SendRecv(0); err == nil {
+ t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringClientSendRecv(t, c, false)
+}
+
+func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstAfterConnect(t, c)
+ testShutdownDuringClientSendRecv(t, c, true)
}
-func testShutdownSendRecv(t *testing.T, c *testConnection) {
+func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
if _, err := c.serverEP.SendRecv(0); err == nil {
t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
@@ -199,14 +337,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownSendRecv(t *testing.T) {
+func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownSendRecv(t, c)
+ testShutdownDuringServerSendRecv(t, c, false)
+}
+
+func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringServerSendRecv(t, c, true)
}
func benchmarkSendRecv(b *testing.B, c *testConnection) {
@@ -218,15 +366,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) {
return
}
if _, err := c.serverEP.RecvFirst(); err != nil {
- b.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
for i := 1; i < b.N; i++ {
if _, err := c.serverEP.SendRecv(0); err != nil {
- b.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ b.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
}
}
if err := c.serverEP.SendLast(0); err != nil {
- b.Fatalf("server Endpoint.SendLast() failed: %v", err)
+ b.Errorf("server Endpoint.SendLast() failed: %v", err)
}
}()
defer func() {
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 7c8977893..a37952637 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -17,17 +17,19 @@ package flipcall
import (
"reflect"
"unsafe"
+
+ "gvisor.dev/gvisor/third_party/gvsync"
)
-// Packets consist of an 8-byte header followed by an arbitrarily-sized
+// Packets consist of a 16-byte header followed by an arbitrarily-sized
// datagram. The header consists of:
//
// - A 4-byte native-endian connection state.
//
// - A 4-byte native-endian datagram length in bytes.
+//
+// - 8 reserved bytes.
const (
- sizeofUint32 = unsafe.Sizeof(uint32(0))
-
// PacketHeaderBytes is the size of a flipcall packet header in bytes. The
// maximum datagram size supported by a flipcall connection is equal to the
// length of the packet window minus PacketHeaderBytes.
@@ -35,7 +37,7 @@ const (
// PacketHeaderBytes is exported to support its use in constant
// expressions. Non-constant expressions may prefer to use
// PacketWindowLengthForDataCap().
- PacketHeaderBytes = 2 * sizeofUint32
+ PacketHeaderBytes = 16
)
func (ep *Endpoint) connState() *uint32 {
@@ -43,7 +45,7 @@ func (ep *Endpoint) connState() *uint32 {
}
func (ep *Endpoint) dataLen() *uint32 {
- return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32))
+ return (*uint32)((unsafe.Pointer)(ep.packet + 4))
}
// Data returns the datagram part of ep's packet window as a byte slice.
@@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte {
bsReflect.Cap = int(ep.dataCap)
return bs
}
+
+// ioSync is a dummy variable used to indicate synchronization to the Go race
+// detector. Compare syscall.ioSync.
+var ioSync int64
+
+func raceBecomeActive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ }
+}
+
+func raceBecomeInactive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ }
+}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index e7dd812b3..b127a2bbb 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon
func (ep *Endpoint) futexSwitchToPeer() error {
// Update connection state to indicate that the peer should be active.
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
- return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState()))
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case csShutdown:
+ return shutdownError{}
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
+ }
}
// Wake the peer's Endpoint.futexSwitchFromPeer().
@@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
case ep.activeState:
return nil
case ep.inactiveState:
- // Continue to FUTEX_WAIT.
+ if ep.isShutdownLocally() {
+ return shutdownError{}
+ }
+ if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
+ }
+ continue
+ case csShutdown:
+ return shutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
- if ep.isShutdownLocally() {
- return shutdownError{}
- }
- if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
- return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
- }
}
}
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 11716af81..0c5f50397 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index e6a8dbd02..4b9321711 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
index 8f3defa25..34d2673ef 100644
--- a/pkg/ilist/BUILD
+++ b/pkg/ilist/BUILD
@@ -1,5 +1,6 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index c8e923a74..a5d980d14 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index 12615240c..fc5f5779b 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 3b8a691f4..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,5 +1,7 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -21,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index c6737bf97..f32244c69 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -19,11 +20,14 @@ go_library(
"pool.go",
"server.go",
"transport.go",
+ "transport_flipcall.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 7dc20aeef..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -20,6 +20,8 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -77,6 +79,47 @@ type Client struct {
// fidPool is the collection of available fids.
fidPool pool
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
// pending is the set of pending messages.
pending map[Tag]*response
pendingMu sync.Mutex
@@ -89,25 +132,12 @@ type Client struct {
// Whoever writes to this channel is permitted to call recv. When
// finished calling recv, this channel should be emptied.
recvr chan bool
-
- // messageSize is the maximum total size of a message.
- messageSize uint32
-
- // payloadSize is the maximum payload size of a read or write
- // request. For large reads and writes this means that the
- // read or write is broken up into buffer-size/payloadSize
- // requests.
- payloadSize uint32
-
- // version is the agreed upon version X of 9P2000.L.Google.X.
- // version 0 implies 9P2000.L.
- version uint32
}
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
return nil, ErrBadVersionString
}
for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+ err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
// The server told us to try again with a lower version.
if err == syscall.EAGAIN {
@@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.version = version
break
}
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacy
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacy
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
// handleOne handles a single incoming message.
//
// This should only be called with the token from recvr. Note that the received
@@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
-// sendRecv performs a roundtrip message exchange.
+// sendRecvLegacy performs a roundtrip message exchange.
//
// This is called by internal functions.
-func (c *Client) sendRecv(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) error {
tag, ok := c.tagPool.Get()
if !ok {
return ErrOutOfTags
@@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error {
return nil
}
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacy(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ return err
+ }
+ }
+
+ // Send the message.
+ err := ch.sendRecv(c, t, r)
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return err
+}
+
// Version returns the negotiated 9P2000.L.Google version number.
func (c *Client) Version() uint32 {
return c.version
}
-// Close closes the underlying socket.
-func (c *Client) Close() error {
- return c.socket.Close()
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 87b2dd61e..29a0afadf 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -35,23 +35,23 @@ func TestVersion(t *testing.T) {
go s.Handle(serverSocket)
// NewClient does a Tversion exchange, so this is our test for success.
- c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString())
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
if err != nil {
t.Fatalf("got %v, expected nil", err)
}
// Check a bogus version string.
- if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a bogus version number.
- if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a too high version number.
- if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN {
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
t.Errorf("got %v expected %v", err, syscall.EAGAIN)
}
@@ -60,3 +60,45 @@ func TestVersion(t *testing.T) {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 999b4f684..ba9a55d6d 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message {
ref.opened = true
ref.openFlags = t.Flags
- return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
@@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
// Replace the FID reference.
cs.InsertFID(t.FID, newRef)
- return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
}
// handle implements handler.handle.
@@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message {
return newErr(err)
}
- return &Rlconnect{File: osFile}
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the given channel.
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index fd9eb1c5d..ffdd7e8c6 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -64,6 +64,21 @@ type filer interface {
SetFilePayload(*fd.FD)
}
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
// Tversion is a version request.
type Tversion struct {
// MSize is the message size to use.
@@ -524,10 +539,7 @@ type Rlopen struct {
// IoUnit is the recommended I/O unit.
IoUnit uint32
- // File may be attached via the socket.
- //
- // This is an extension specific to this package.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType {
return MsgRlopen
}
-// FilePayload returns the file payload.
-func (r *Rlopen) FilePayload() *fd.FD {
- return r.File
-}
-
-// SetFilePayload sets the received file.
-func (r *Rlopen) SetFilePayload(file *fd.FD) {
- r.File = file
-}
-
// String implements fmt.Stringer.
func (r *Rlopen) String() string {
return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
@@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string {
// Rlconnect is a connect response.
type Rlconnect struct {
- // File is a host socket.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType {
return MsgRlconnect
}
-// FilePayload returns the file payload.
-func (r *Rlconnect) FilePayload() *fd.FD {
- return r.File
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// SetFilePayload sets the received file.
-func (r *Rlconnect) SetFilePayload(file *fd.FD) {
- r.File = file
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tchannel) Decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tchannel) Encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
}
// String implements fmt.Stringer.
-func (r *Rlconnect) String() string {
- return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// Decode implements encoder.Decode.
+func (r *Rchannel) Decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rchannel) Encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
const maxCacheSize = 3
@@ -2356,4 +2409,6 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index e12831dbd..25530adca 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -378,6 +378,8 @@ const (
MsgRlconnect = 137
MsgTallocate = 138
MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 6e939a49a..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 95846e5f7..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
@@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}()
// Create the client.
- client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString())
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
if err != nil {
serverSocket.Close()
clientSocket.Close()
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b294efbb0..69c886a5d 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -21,6 +21,9 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -45,7 +48,6 @@ type Server struct {
}
// NewServer returns a new server.
-//
func NewServer(attacher Attacher) *Server {
return &Server{
attacher: attacher,
@@ -85,6 +87,8 @@ type connState struct {
// version 0 implies 9P2000.L.
version uint32
+ // -- below relates to the legacy handler --
+
// recvOkay indicates that a receive may start.
recvOkay chan bool
@@ -93,6 +97,20 @@ type connState struct {
// sendDone is signalled when a send is finished.
sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
}
// fidRef wraps a node and tracks references.
@@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) {
<-ch
}
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ res.service(cs)
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
// handleRequest handles a single request.
//
// The recvDone channel is signaled when recv is done (with a error if
@@ -428,41 +539,20 @@ func (cs *connState) handleRequest() {
}
// Handle the message.
- var r message // r is the response.
- defer func() {
- if r == nil {
- // Don't allow a panic to propagate.
- recover()
+ r := cs.handle(m)
- // Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
- // Wrap in an EFAULT error; we don't really have a
- // better way to describe this kind of error. It will
- // usually manifest as a result of the test framework.
- r = newErr(syscall.EFAULT)
- }
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
- // Clear the tag before sending. That's because as soon as this
- // hits the wire, the client can legally send another message
- // with the same tag.
- cs.ClearTag(tag)
-
- // Send back the result.
- cs.sendMu.Lock()
- err = send(cs.conn, tag, r)
- cs.sendMu.Unlock()
- cs.sendDone <- err
- }()
- if handler, ok := m.(handler); ok {
- // Call the message handler.
- r = handler.handle(cs)
- } else {
- // Produce an ENOSYS error.
- r = newErr(syscall.ENOSYS)
- }
+ // Return the message to the cache.
msgRegistry.put(m)
- m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
@@ -477,7 +567,27 @@ func (cs *connState) stop() {
close(cs.recvDone)
close(cs.sendDone)
- for _, fidRef := range cs.fids {
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
// Drop final reference in the FID table. Note this should
// always close the file, since we've ensured that there are no
// handlers running via the wait for Pending => 0 below.
@@ -510,7 +620,7 @@ func (cs *connState) service() error {
for i := 0; i < pending; i++ {
<-cs.sendDone
}
- return err
+ return nil
}
// This handler is now pending.
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 5648df589..6e8b4bbcd 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -54,7 +54,10 @@ const (
headerLength uint32 = 7
// maximumLength is the largest possible message.
- maximumLength uint32 = 4 * 1024 * 1024
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
// initialBufferLength is the initial data buffer we allocate.
initialBufferLength uint32 = 64
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..7cdf4ecc3
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,263 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and select the number of available processes as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because each
+// will account for channelSize memory used, which can be large.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Not that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, syscall.EIO // Map everything to EIO.
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.Encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ n, err := ch.data.SendRecv(ssz)
+ if err != nil {
+ if n > 0 {
+ return n, nil
+ }
+ return 0, syscall.EIO // See above.
+ }
+
+ return n, nil
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.Decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return nil, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
+
+// sendRecv sends the given message over the channel.
+//
+// This is used by the client.
+func (ch *channel) sendRecv(c *Client, m, r message) error {
+ rsz, err := ch.send(m)
+ if err != nil {
+ return err
+ }
+ _, err = ch.recv(r, rsz)
+ return err
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index cdb3bc841..2f50ff3ea 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) {
t.Fatalf("unable to create file: %v", err)
}
- if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
t.Fatalf("send got err %v expected nil", err)
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index c2a2885ae..f1ffdd23a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 7
+ highestSupportedVersion uint32 = 8
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool {
func versionSupportsTallocate(v uint32) bool {
return v >= 7
}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 697e7a2f4..078f084b2 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 9c08452fc..827385139 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "weak_ref_list",
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 828e9b5c1..ad69e0757 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -215,8 +215,8 @@ type AtomicRefCount struct {
type LeakMode uint32
const (
- // uninitializedLeakChecking indicates that the leak checker has not yet been initialized.
- uninitializedLeakChecking LeakMode = iota
+ // UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
+ UninitializedLeakChecking LeakMode = iota
// NoLeakChecking indicates that no effort should be made to check for
// leaks.
@@ -318,7 +318,7 @@ func (r *AtomicRefCount) finalize() {
switch LeakMode(atomic.LoadUint32(&leakMode)) {
case NoLeakChecking:
return
- case uninitializedLeakChecking:
+ case UninitializedLeakChecking:
note = "(Leak checker uninitialized): "
}
if n := r.ReadRefs(); n != 0 {
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index d1024e49d..af94e944d 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index 0a3d92854..be328db12 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -35,7 +35,7 @@ type sockFprog struct {
//go:nosplit
func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
// PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
return errno
}
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
index f38fb39f3..22abdc69f 100644
--- a/pkg/secio/BUILD
+++ b/pkg/secio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 694486296..12d7c77d2 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:private"],
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index 53989301f..2d6379c86 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -8,5 +8,7 @@ package_group(
packages = [
"//pkg/sentry/...",
"//runsc/...",
+ # Code generated by go_marshal relies on go_marshal libraries.
+ "//tools/go_marshal/...",
],
)
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index bf802d1b6..5522cecd0 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 7e8918722..0c86197f7 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "device",
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index d7259b47b..3119a61b6 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fs",
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index fbca06761..3cb73bd78 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
// Remove removes the given file or symlink. The root dirent is used to
// resolve name, and must not be nil.
-func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error {
// Check the root.
if root == nil {
panic("Dirent.Remove: root must not be nil")
@@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
return syscall.EISDIR
+ } else if dirPath {
+ return syscall.ENOTDIR
}
// Remove cannot remove a mount point.
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 884e3ff06..47bc72a88 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) {
}
d := f.Dirent
- if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil {
+ if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil {
t.Fatalf("root.Remove(root, %q) failed: %v", name, err)
}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index bf00b9c09..b9bd9ed17 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fdpipe",
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index bb8117f89..c0a6e884b 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -515,6 +515,11 @@ type lockedReader struct {
// File is the file to read from.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Read, not ReadAt.
+ Offset int64
}
// Read implements io.Reader.Read.
@@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) {
if r.Ctx.Interrupted() {
return 0, syserror.ErrInterrupted
}
- n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset)
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
+ r.Offset += n
return int(n), err
}
@@ -544,11 +550,21 @@ type lockedWriter struct {
// File is the file to write to.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Write, not WriteAt.
+ Offset int64
}
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
- return w.WriteAt(buf, w.File.offset)
+ if w.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := w.WriteAt(buf, w.Offset)
+ w.Offset += int64(n)
+ return int(n), err
}
// WriteAt implements io.Writer.WriteAt.
@@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
// io.Copy, since our own Write interface does not have this same
// contract. Enforce that here.
for written < len(buf) {
+ if w.Ctx.Interrupted() {
+ return written, syserror.ErrInterrupted
+ }
var n int64
n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
if n > 0 {
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index d86f5bf45..b88303f17 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -15,6 +15,8 @@
package fs
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -105,8 +107,11 @@ type FileOperations interface {
// on the destination, following by a buffered copy with standard Read
// and Write operations.
//
+ // If dup is set, the data should be duplicated into the destination
+ // and retained.
+ //
// The same preconditions as Read apply.
- WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error)
+ WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error)
// Write writes src to file at offset and returns the number of bytes
// written which must be greater than or equal to 0. Like Read, file
@@ -126,7 +131,7 @@ type FileOperations interface {
// source. See WriteTo for details regarding how this is called.
//
// The same preconditions as Write apply; FileFlags.Write must be set.
- ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error)
+ ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error)
// Fsync writes buffered modifications of file and/or flushes in-flight
// operations to backing storage based on syncType. The range to sync is
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 9820f0b13..225e40186 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/refs"
@@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme
}
// WriteTo implements FileOperations.WriteTo.
-func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) {
err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
- n, err = ops.WriteTo(ctx, file, dst, opts)
+ n, err = ops.WriteTo(ctx, file, dst, count, dup)
return err // Will overwrite itself.
})
return
@@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm
}
// ReadFrom implements FileOperations.ReadFrom.
-func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) {
// See above; f.upper must be non-nil.
- return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts)
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count)
}
// Fsync implements FileOperations.Fsync.
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6499f87ac..b4ac83dc4 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "dirty_set_impl",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 626b9126a..fc5b3b1a1 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -15,6 +15,8 @@
package fsutil
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu
type FileNoSplice struct{}
// WriteTo implements fs.FileOperations.WriteTo.
-func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index d2495cb83..693625ddc 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
- if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index e70bc28fb..dd80757dc 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -66,10 +66,8 @@ type CachingInodeOperations struct {
// mfp is used to allocate memory that caches backingFile's contents.
mfp pgalloc.MemoryFileProvider
- // forcePageCache indicates the sentry page cache should be used regardless
- // of whether the platform supports host mapped I/O or not. This must not be
- // modified after inode creation.
- forcePageCache bool
+ // opts contains options. opts is immutable.
+ opts CachingInodeOperationsOptions
attrMu sync.Mutex `state:"nosave"`
@@ -116,6 +114,20 @@ type CachingInodeOperations struct {
refs frameRefSet
}
+// CachingInodeOperationsOptions configures a CachingInodeOperations.
+//
+// +stateify savable
+type CachingInodeOperationsOptions struct {
+ // If ForcePageCache is true, use the sentry page cache even if a host file
+ // descriptor is available.
+ ForcePageCache bool
+
+ // If LimitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host file descriptor mappings returned by
+ // CachingInodeOperations.Translate().
+ LimitHostFDTranslation bool
+}
+
// CachedFileObject is a file that may require caching.
type CachedFileObject interface {
// ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
@@ -128,12 +140,16 @@ type CachedFileObject interface {
// WriteFromBlocksAt may return a partial write without an error.
WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
- // SetMaskedAttributes sets the attributes in attr that are true in mask
- // on the backing file.
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only ATime or MTime
+ // and the CachedFileObject has an FD to the file, then this operation
+ // is a noop unless forceSetTimestamps is true. This avoids an extra
+ // RPC to the gofer in the open-read/write-close case, when the
+ // timestamps on the file will be updated by the host kernel for us.
//
// SetMaskedAttributes may be called at any point, regardless of whether
// the file was opened.
- SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
// Allocate allows the caller to reserve disk space for the inode.
// It's equivalent to fallocate(2) with 'mode=0'.
@@ -159,7 +175,7 @@ type CachedFileObject interface {
// NewCachingInodeOperations returns a new CachingInodeOperations backed by
// a CachedFileObject and its initial unstable attributes.
-func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations {
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
@@ -167,7 +183,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject
return &CachingInodeOperations{
backingFile: backingFile,
mfp: mfp,
- forcePageCache: forcePageCache,
+ opts: opts,
attr: uattr,
hostFileMapper: NewHostFileMapper(),
}
@@ -212,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Perms: true}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
return false
}
c.attr.Perms = perms
@@ -234,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode,
UID: owner.UID.Ok(),
GID: owner.GID.Ok(),
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
return err
}
if owner.UID.Ok() {
@@ -270,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In
AccessTime: !ts.ATimeOmit,
ModificationTime: !ts.MTimeOmit,
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
return err
}
if !ts.ATimeOmit {
@@ -293,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
}
@@ -382,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
c.dirtyAttr.Size = false
// Write out cached attributes.
- if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
c.attrMu.Unlock()
return err
}
@@ -763,7 +781,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// and memory mappings, and false if c.cache may contain data cached from
// c.backingFile.
func (c *CachingInodeOperations) useHostPageCache() bool {
- return !c.forcePageCache && c.backingFile.FD() >= 0
+ return !c.opts.ForcePageCache && c.backingFile.FD() >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
@@ -784,11 +802,6 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi
mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End})
}
}
- if c.useHostPageCache() && !usage.IncrementalMappedAccounting {
- for _, r := range mapped {
- usage.MemoryAccounting.Inc(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return nil
}
@@ -802,11 +815,6 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma
c.hostFileMapper.DecRefOn(r)
}
if c.useHostPageCache() {
- if !usage.IncrementalMappedAccounting {
- for _, r := range unmapped {
- usage.MemoryAccounting.Dec(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return
}
@@ -835,11 +843,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp
func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
// Hot path. Avoid defer.
if c.useHostPageCache() {
+ mr := optional
+ if c.opts.LimitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
return []memmap.Translation{
{
- Source: optional,
+ Source: mr,
File: c,
- Offset: optional.Start,
+ Offset: mr.Start,
Perms: usermem.AnyAccess,
},
}, nil
@@ -985,9 +997,7 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
seg, gap = seg.NextNonEmpty()
case gap.Ok() && gap.Start() < fr.End:
newRange := gap.Range().Intersect(fr)
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
default:
c.refs.MergeAdjacent(fr)
@@ -1008,9 +1018,7 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
for seg.Ok() && seg.Start() < fr.End {
seg = c.refs.Isolate(seg, fr)
if old := seg.Value(); old == 1 {
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
seg = c.refs.Remove(seg).NextSegment()
} else {
seg.SetValue(old - 1)
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index dc19255ed..129f314c8 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block
return srcs.NumBytes(), nil
}
-func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) {
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
Perms: fs.FilePermsFromMode(0444),
})
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
perms := fs.FilePermsFromMode(0777)
@@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) {
ModificationTime: epoch,
StatusChangeTime: epoch,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil {
@@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) {
uattr := fs.UnstableAttr{
Size: 0,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.Truncate(ctx, nil, uattr.Size); err != nil {
@@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B
return w.WriteFromBlocks(srcs)
}
-func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -280,7 +280,7 @@ func TestRead(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
@@ -336,7 +336,7 @@ func TestWrite(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 6b993928c..2b71ca0e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "gofer",
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 69999dc28..8f8ab5d29 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -54,6 +54,10 @@ const (
// sandbox using files backed by the gofer. If set to false, unix sockets
// cannot be bound to gofer files without an overlay on top.
privateUnixSocketKey = "privateunixsocket"
+
+ // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
+ // true.
+ limitHostFDTranslationKey = "limit_host_fd_translation"
)
// defaultAname is the default attach name.
@@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// opts are parsed 9p mount options.
type opts struct {
- fd int
- aname string
- policy cachePolicy
- msize uint32
- version string
- privateunixsocket bool
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+ limitHostFDTranslation bool
}
// options parses mount(2) data into structured options.
@@ -237,6 +242,11 @@ func options(data string) (opts, error) {
delete(options, privateUnixSocketKey)
}
+ if _, ok := options[limitHostFDTranslationKey]; ok {
+ o.limitHostFDTranslation = true
+ delete(options, limitHostFDTranslationKey)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 95b064aea..d918d6620 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
- if i.skipSetAttr(mask) {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
return nil
}
as, ans := attr.AccessTime.Unix()
@@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa
// when:
// - Mask is empty
// - Mask contains only attributes that cannot be set in the gofer
-// - Mask contains only atime and/or mtime, and host FD exists
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
//
// Updates to atime and mtime can be skipped because cached value will be
// "close enough" to host value, given that operation went directly to host FD.
// Skipping atime updates is particularly important to reduce the number of
// operations sent to the Gofer for readonly files.
-func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
// First remove attributes that cannot be updated.
cpy := mask
cpy.Type = false
@@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
return false
}
+ // If forceSetTimestamps was passed, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
i.handlesMu.RLock()
defer i.handlesMu.RUnlock()
return (i.readHandles != nil && i.readHandles.Host != nil) ||
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 69d08a627..50da865c1 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -117,6 +117,11 @@ type session struct {
// Flags provided to the mount.
superBlockFlags fs.MountSourceFlags `state:"wait"`
+ // limitHostFDTranslation is the value used for
+ // CachingInodeOperationsOptions.LimitHostFDTranslation for all
+ // CachingInodeOperations created by the session.
+ limitHostFDTranslation bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
uattr := unstable(ctx, valid, attr, s.mounter, s.client)
return sattr, &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: s.superBlockFlags.ForcePageCache,
+ LimitHostFDTranslation: s.limitHostFDTranslation,
+ }),
}
}
@@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
// Construct the session.
s := session{
- connID: dev,
- msize: o.msize,
- version: o.version,
- cachePolicy: o.policy,
- aname: o.aname,
- superBlockFlags: superBlockFlags,
- mounter: mounter,
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ limitHostFDTranslation: o.limitHostFDTranslation,
+ mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index b1080fb1a..3e532332e 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "host",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 679d8321a..a6e4a09e3 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
if mask.Empty() {
return nil
}
@@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
return unstableAttr(i.mops, &s), nil
}
-// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+// Allocate implements fsutil.CachedFileObject.Allocate.
func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
return syscall.Fallocate(i.FD(), 0, offset, length)
}
@@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
// Build the fs.InodeOperations.
uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
iops := &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: msrc.Flags.ForcePageCache,
+ }),
}
// Return the fs.Inode.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 2526412a4..90331e3b2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -43,12 +43,15 @@ type TTYFileOperations struct {
// fgProcessGroup is the foreground process group that is currently
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+
+ termios linux.KernelTermios
}
// newTTYFile returns a new fs.File that wraps a TTY FD.
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
})
}
@@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
t.mu.Lock()
defer t.mu.Unlock()
- // Are we allowed to do the write?
- if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
- return 0, err
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
}
return t.fileOperations.Write(ctx, file, src, offset)
}
@@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
return 0, err
}
err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
return 0, err
case linux.TIOCGPGRP:
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index c7f4e2d13..ba3e0233d 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"sync/atomic"
@@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
// WriteTo implements FileOperations.WriteTo.
-func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
@@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
}
// ReadFrom implements FileOperations.ReadFrom.
-func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 08d7c0c57..5a7a5b8cd 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "lock_range",
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 9b713e785..ac0398bd9 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -171,8 +171,6 @@ type MountNamespace struct {
// NewMountNamespace returns a new MountNamespace, with the provided node at the
// root, and the given cache size. A root must always be provided.
func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
- creds := auth.CredentialsFromContext(ctx)
-
// Set the root dirent and id on the root mount. The reference returned from
// NewDirent will be donated to the MountNamespace constructed below.
d := NewDirent(ctx, root, "/")
@@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error
d: newRootMount(1, d),
}
+ creds := auth.CredentialsFromContext(ctx)
mns := MountNamespace{
userns: creds.UserNamespace,
root: d,
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 70ed854a8..1c93e8886 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "proc",
@@ -31,7 +33,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/log",
"//pkg/sentry/context",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 9adb23608..5e28982c5 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -17,10 +17,10 @@ package proc
import (
"bytes"
"fmt"
+ "io"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -28,9 +28,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// newNet creates a new proc net entry.
@@ -57,9 +59,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo
"ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
"route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")),
"tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")),
-
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
@@ -216,7 +217,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
for _, se := range n.k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -297,6 +298,42 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
return data, 0
}
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.Uint16() because Go does not support
+ // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to
+ // binary.BigEndian.Uint16() require a read of binary.BigEndian and an
+ // interface method call, defeating inlining.
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+func writeInetAddr(w io.Writer, a linux.SockAddrInet) {
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+}
+
// netTCP implements seqfile.SeqSource for /proc/net/tcp.
//
// +stateify savable
@@ -311,6 +348,9 @@ func (*netTCP) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
t := kernel.TaskFromContext(ctx)
if h != nil {
@@ -321,7 +361,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
for _, se := range n.k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -343,27 +383,23 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
// Field: sl; entry number.
fmt.Fprintf(&buf, "%4d: ", se.ID)
- portBuf := make([]byte, 2)
-
// Field: local_adddress.
var localAddr linux.SockAddrInet
- if local, _, err := sops.GetSockName(t); err == nil {
- localAddr = *local.(*linux.SockAddrInet)
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
}
- binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(localAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, localAddr)
// Field: rem_address.
var remoteAddr linux.SockAddrInet
- if remote, _, err := sops.GetPeerName(t); err == nil {
- remoteAddr = *remote.(*linux.SockAddrInet)
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
}
- binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, remoteAddr)
// Field: state; socket state.
fmt.Fprintf(&buf, "%02X ", sops.State())
@@ -386,7 +422,8 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
fmt.Fprintf(&buf, "%5d ", 0)
} else {
- fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@@ -438,3 +475,125 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
return data, 0
}
+
+// netUDP implements seqfile.SeqSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUDP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef()
+ // Not udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(&buf, "%5d: ", se.ID)
+
+ // Field: local_adddress.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(&buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(&buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%08X ", 0)
+
+ // Field: uid.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(&buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(&buf, "%8d ", sfile.InodeID())
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while deferencing the weakref to this socket.
+ fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(&buf, "%d", 0)
+
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n"),
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 20c3eefc8..76433c7d0 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "seqfile",
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 516efcc4c..d0f351e5a 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "ramfs",
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index eed1c2854..311798811 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -18,7 +18,6 @@ import (
"io"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/secio"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
// Check whether or not the objects being sliced are stream-oriented
- // (i.e. pipes or sockets). If yes, we elide checks and offset locks.
- srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr)
- dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr)
+ // (i.e. pipes or sockets). For all stream-oriented files and files
+ // where a specific offiset is not request, we acquire the file mutex.
+ // This has two important side effects. First, it provides the standard
+ // protection against concurrent writes that would mutate the offset.
+ // Second, it prevents Splice deadlocks. Only internal anonymous files
+ // implement the ReadFrom and WriteTo methods directly, and since such
+ // anonymous files are referred to by a unique fs.File object, we know
+ // that the file mutex takes strict precedence over internal locks.
+ // Since we enforce lock ordering here, we can't deadlock by using
+ // using a file in two different splice operations simultaneously.
+ srcPipe := !IsRegular(src.Dirent.Inode.StableAttr)
+ dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr)
+ dstAppend := !dstPipe && dst.Flags().Append
+ srcLock := srcPipe || !opts.SrcOffset
+ dstLock := dstPipe || !opts.DstOffset || dstAppend
- if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset {
+ switch {
+ case srcLock && dstLock:
switch {
case dst.UniqueID < src.UniqueID:
// Acquire dst first.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
if !src.mu.Lock(ctx) {
+ dst.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
case dst.UniqueID > src.UniqueID:
// Acquire src first.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
if !dst.mu.Lock(ctx) {
+ src.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
case dst.UniqueID == src.UniqueID:
// Acquire only one lock; it's the same file. This is a
// bit of a edge case, but presumably it's possible.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
+ srcLock = false // Only need one unlock.
}
// Use both offsets (locked).
opts.DstStart = dst.offset
opts.SrcStart = src.offset
- } else if !dstPipe && !opts.DstOffset {
+ case dstLock:
// Acquire only dst.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
opts.DstStart = dst.offset // Safe: locked.
- } else if !srcPipe && !opts.SrcOffset {
+ case srcLock:
// Acquire only src.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
opts.SrcStart = src.offset // Safe: locked.
}
- // Check append-only mode and the limit.
- if !dstPipe {
+ var err error
+ if dstAppend {
unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append)
defer unlock()
- if dst.Flags().Append {
- if opts.DstOffset {
- // We need to acquire the lock.
- if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
- }
- defer dst.mu.Unlock()
- }
- // Figure out the appropriate offset to use.
- if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil {
- return 0, err
- }
- }
+ // Figure out the appropriate offset to use.
+ err = dst.offsetForAppend(ctx, &opts.DstStart)
+ }
+ if err == nil && !dstPipe {
// Enforce file limits.
limit, ok := dst.checkLimit(ctx, opts.DstStart)
switch {
case ok && limit == 0:
- return 0, syserror.ErrExceedsFileSizeLimit
+ err = syserror.ErrExceedsFileSizeLimit
case ok && limit < opts.Length:
opts.Length = limit // Cap the write.
}
}
+ if err != nil {
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+ return 0, err
+ }
- // Attempt to do a WriteTo; this is likely the most efficient.
- //
- // The underlying implementation may be able to donate buffers.
- newOpts := SpliceOpts{
- Length: opts.Length,
- SrcStart: opts.SrcStart,
- SrcOffset: !srcPipe,
- Dup: opts.Dup,
- DstStart: opts.DstStart,
- DstOffset: !dstPipe,
+ // Construct readers and writers for the splice. This is used to
+ // provide a safer locking path for the WriteTo/ReadFrom operations
+ // (since they will otherwise go through public interface methods which
+ // conflict with locking done above), and simplifies the fallback path.
+ w := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ Offset: opts.DstStart,
}
- n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts)
- if n == 0 && err != nil {
- // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also
- // be more efficient than a copy if buffers are cached or readily
- // available. (It's unlikely that they can actually be donate
- n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts)
+ r := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ Offset: opts.SrcStart,
}
- if n == 0 && err != nil {
- // If we've failed up to here, and at least one of the sources
- // is a pipe or socket, then we can't properly support dup.
- // Return an error indicating that this operation is not
- // supported.
- if (srcPipe || dstPipe) && newOpts.Dup {
- return 0, syserror.EINVAL
- }
- // We failed to splice the files. But that's fine; we just fall
- // back to a slow path in this case. This copies without doing
- // any mode changes, so should still be more efficient.
- var (
- r io.Reader
- w io.Writer
- )
- fw := &lockedWriter{
- Ctx: ctx,
- File: dst,
- }
- if newOpts.DstOffset {
- // Use the provided offset.
- w = secio.NewOffsetWriter(fw, newOpts.DstStart)
- } else {
- // Writes will proceed with no offset.
- w = fw
- }
- fr := &lockedReader{
- Ctx: ctx,
- File: src,
- }
- if newOpts.SrcOffset {
- // Limit to the given offset and length.
- r = io.NewSectionReader(fr, opts.SrcStart, opts.Length)
- } else {
- // Limit just to the given length.
- r = &io.LimitedReader{fr, opts.Length}
- }
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
+ // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
+ // more efficient than a copy if buffers are cached or readily
+ // available. (It's unlikely that they can actually be donated).
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length)
+ }
- // Copy between the two.
- n, err = io.Copy(w, r)
+ // Support one last fallback option, but only if at least one of
+ // the source and destination are regular files. This is because
+ // if we block at some point, we could lose data. If the source is
+ // not a pipe then reading is not destructive; if the destination
+ // is a regular file, then it is guaranteed not to block writing.
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
+ // Fallback to an in-kernel copy.
+ n, err = io.Copy(w, &io.LimitedReader{
+ R: r,
+ N: opts.Length,
+ })
}
// Update offsets, if required.
@@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
}
+ // Drop locks.
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+
return n, err
}
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 59403d9db..f8bf663bb 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i
}
// Notify implements ktime.TimerListener.Notify.
-func (t *TimerOperations) Notify(exp uint64) {
+func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
atomic.AddUint64(&t.val, exp)
t.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 8f7eb5757..11b680929 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tmpfs",
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 291164986..25811f668 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tty",
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index a41101339..b0c286b7a 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
go_template_instance(
@@ -79,7 +81,7 @@ go_test(
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index 9fddb4c4c..bfc46dfa6 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..0b471d121 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 907d35b7e..2d50e30aa 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "disklayout",
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 49b57a2d6..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -33,7 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
const (
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index d2450e810..7e364c5fd 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..c620227c9 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 3d8a4deaf..ade6ac946 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index f989f2f8b..359468ccc 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,6 +7,7 @@ go_library(
name = "hostcpu",
srcs = [
"getcpu_amd64.s",
+ "getcpu_arm64.s",
"hostcpu.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for
+// the lack of an optimazed way of getting the current CPU number on arm64.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 41bee9a22..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,9 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "pending_signals_list",
@@ -83,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index f46c43128..65427b112 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "epoll_list",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 1c5f979d4..983ca67ed 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "eventfd",
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 6a31dc044..41f44999c 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "atomicptr_bucket",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8c1f79ab5..3cda03891 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -24,6 +24,7 @@
// TaskSet.mu
// SignalHandlers.mu
// Task.mu
+// runningTasksMu
//
// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
@@ -135,6 +136,22 @@ type Kernel struct {
// syslog is the kernel log.
syslog syslog
+ // runningTasksMu synchronizes disable/enable of cpuClockTicker when
+ // the kernel is idle (runningTasks == 0).
+ //
+ // runningTasksMu is used to exclude critical sections when the timer
+ // disables itself and when the first active task enables the timer,
+ // ensuring that tasks always see a valid cpuClock value.
+ runningTasksMu sync.Mutex `state:"nosave"`
+
+ // runningTasks is the total count of tasks currently in
+ // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are
+ // not blocked or stopped.
+ //
+ // runningTasks must be accessed atomically. Increments from 0 to 1 are
+ // further protected by runningTasksMu (see incRunningTasks).
+ runningTasks int64
+
// cpuClock is incremented every linux.ClockTick. cpuClock is used to
// measure task CPU usage, since sampling monotonicClock twice on every
// syscall turns out to be unreasonably expensive. This is similar to how
@@ -150,6 +167,22 @@ type Kernel struct {
// cpuClockTicker increments cpuClock.
cpuClockTicker *ktime.Timer `state:"nosave"`
+ // cpuClockTickerDisabled indicates that cpuClockTicker has been
+ // disabled because no tasks are running.
+ //
+ // cpuClockTickerDisabled is protected by runningTasksMu.
+ cpuClockTickerDisabled bool
+
+ // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the
+ // point it was disabled. It is cached here to avoid a lock ordering
+ // violation with cpuClockTicker.mu when runningTaskMu is held.
+ //
+ // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is
+ // true.
+ //
+ // cpuClockTickerSetting is protected by runningTasksMu.
+ cpuClockTickerSetting ktime.Setting
+
// fdMapUids is an ever-increasing counter for generating FDTable uids.
//
// fdMapUids is mutable, and is accessed using atomic memory operations.
@@ -912,6 +945,102 @@ func (k *Kernel) resumeTimeLocked() {
}
}
+func (k *Kernel) incRunningTasks() {
+ for {
+ tasks := atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // Standard case. Simply increment.
+ if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) {
+ continue
+ }
+ return
+ }
+
+ // Transition from 0 -> 1. Synchronize with other transitions and timer.
+ k.runningTasksMu.Lock()
+ tasks = atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // We're no longer the first task, no need to
+ // re-enable.
+ atomic.AddInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ if !k.cpuClockTickerDisabled {
+ // Timer was never disabled.
+ atomic.StoreInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ // We need to update cpuClock for all of the ticks missed while we
+ // slept, and then re-enable the timer.
+ //
+ // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify
+ // always increments cpuClock by 1 regardless of the number of
+ // expirations as a heuristic to avoid over-accounting in cases of CPU
+ // throttling.
+ //
+ // We want to cover the normal case, when all time should be accounted,
+ // so we increment for all expirations. Throttling is less concerning
+ // here because the ticker is only disabled from Notify. This means
+ // that Notify must schedule and compensate for the throttled period
+ // before the timer is disabled. Throttling while the timer is disabled
+ // doesn't matter, as nothing is running or reading cpuClock anyways.
+ //
+ // S/R also adds complication, as there are two cases. Recall that
+ // monotonicClock will jump forward on restore.
+ //
+ // 1. If the ticker is enabled during save, then on Restore Notify is
+ // called with many expirations, covering the time jump, but cpuClock
+ // is only incremented by 1.
+ //
+ // 2. If the ticker is disabled during save, then after Restore the
+ // first wakeup will call this function and cpuClock will be
+ // incremented by the number of expirations across the S/R.
+ //
+ // These cause very different value of cpuClock. But again, since
+ // nothing was running while the ticker was disabled, those differences
+ // don't matter.
+ setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ if exp > 0 {
+ atomic.AddUint64(&k.cpuClock, exp)
+ }
+
+ // Now that cpuClock is updated it is safe to allow other tasks to
+ // transition to running.
+ atomic.StoreInt64(&k.runningTasks, 1)
+
+ // N.B. we must unlock before calling Swap to maintain lock ordering.
+ //
+ // cpuClockTickerDisabled need not wait until after Swap to become
+ // true. It is sufficient that the timer *will* be enabled.
+ k.cpuClockTickerDisabled = false
+ k.runningTasksMu.Unlock()
+
+ // This won't call Notify (unless it's been ClockTick since setting.At
+ // above). This means we skip the thread group work in Notify. However,
+ // since nothing was running while we were disabled, none of the timers
+ // could have expired.
+ k.cpuClockTicker.Swap(setting)
+
+ return
+ }
+}
+
+func (k *Kernel) decRunningTasks() {
+ tasks := atomic.AddInt64(&k.runningTasks, -1)
+ if tasks < 0 {
+ panic(fmt.Sprintf("Invalid running count %d", tasks))
+ }
+
+ // Nothing to do. The next CPU clock tick will disable the timer if
+ // there is still nothing running. This provides approximately one tick
+ // of slack in which we can switch back and forth between idle and
+ // active without an expensive transition.
+}
+
// WaitExited blocks until all tasks in k have exited.
func (k *Kernel) WaitExited() {
k.tasks.liveGoroutines.Wait()
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4d15cca85..2ce8952e2 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "buffer_list",
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 69ef2a720..95bee2d37 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
@@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
+// WriteFromReader writes to the buffer from an io.Reader.
+func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ dst := b.data[b.write:]
+ if count < int64(len(dst)) {
+ dst = b.data[b.write:][:count]
+ }
+ n, err := r.Read(dst)
+ b.write += n
+ return int64(n), err
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
@@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return n, err
}
+// ReadToWriter reads from the buffer into an io.Writer.
+func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
+ src := b.data[b.read:b.write]
+ if count < int64(len(src)) {
+ src = b.data[b.read:][:count]
+ }
+ n, err := w.Write(src)
+ if !dup {
+ b.read += n
+ }
+ return int64(n), err
+}
+
// bufferPool is a pool for buffers.
var bufferPool = sync.Pool{
New: func() interface{} {
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 247e2928e..93b50669f 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F
}
}
+type readOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit limits subsequence reads.
+ limit func(int64)
+
+ // read performs the actual read operation.
+ read func(*buffer) (int64, error)
+}
+
// read reads data from the pipe into dst and returns the number of bytes
// read, or returns ErrWouldBlock if the pipe is empty.
//
// Precondition: this pipe must have readers.
-func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
// Don't block for a zero-length read even if the pipe is empty.
- if dst.NumBytes() == 0 {
+ if ops.left() == 0 {
return 0, nil
}
@@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Limit how much we consume.
- if dst.NumBytes() > p.size {
- dst = dst.TakeFirst64(p.size)
+ if ops.left() > p.size {
+ ops.limit(p.size)
}
done := int64(0)
- for dst.NumBytes() > 0 {
+ for ops.left() > 0 {
// Pop the first buffer.
first := p.data.Front()
if first == nil {
@@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := dst.CopyOutFrom(ctx, first)
+ n, err := ops.read(first)
done += int64(n)
p.size -= n
- dst = dst.DropFirst64(n)
// Empty buffer?
if first.Empty() {
@@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
return done, nil
}
+// dup duplicates all data from this pipe into the given writer.
+//
+// There is no blocking behavior implemented here. The writer may propagate
+// some blocking error. All the writes must be complete writes.
+func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Is the pipe empty?
+ if p.size == 0 {
+ if !p.HasWriters() {
+ // See above.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if ops.left() > p.size {
+ ops.limit(p.size)
+ }
+
+ done := int64(0)
+ for buf := p.data.Front(); buf != nil; buf = buf.Next() {
+ n, err := ops.read(buf)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ return done, nil
+}
+
+type writeOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit should limit subsequent writes.
+ limit func(int64)
+
+ // write should write to the provided buffer.
+ write func(*buffer) (int64, error)
+}
+
// write writes data from sv into the pipe and returns the number of bytes
// written. If no bytes are written because the pipe is full (or has less than
// atomicIOBytes free capacity), write returns ErrWouldBlock.
//
// Precondition: this pipe must have writers.
-func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
- wanted := src.NumBytes()
+ wanted := ops.left()
if avail := p.max - p.size; wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
- // Limit to the available capacity.
- src = src.TakeFirst64(avail)
+ ops.limit(avail)
}
done := int64(0)
- for src.NumBytes() > 0 {
+ for ops.left() > 0 {
// Need a new buffer?
last := p.data.Back()
if last == nil || last.Full() {
@@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := src.CopyInTo(ctx, last)
+ n, err := ops.write(last)
done += int64(n)
p.size += n
- src = src.DropFirst64(n)
// Handle errors.
if err != nil {
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index f69dbf27b..7c307f013 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"math"
"syscall"
@@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() {
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, dst)
+ n, err := rw.Pipe.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return rw.Pipe.dup(ctx, ops)
+ }
+ n, err := rw.Pipe.read(ctx, ops)
if n > 0 {
rw.Pipe.Notify(waiter.EventOut)
}
@@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, src)
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
if n > 0 {
rw.Pipe.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index c5d095af7..2e861a5a8 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() {
}
// Notify implements ktime.TimerListener.Notify.
-func (it *IntervalTimer) Notify(exp uint64) {
+func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
if it.target == nil {
- return
+ return ktime.Setting{}, false
}
it.target.tg.pidns.owner.mu.RLock()
@@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) {
if it.sigpending {
it.overrunCur += exp
- return
+ return ktime.Setting{}, false
}
// sigpending must be set before sendSignalTimerLocked() so that it can be
@@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) {
if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
it.signalRejectedLocked()
}
+
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
index 1725b8562..98ea7a0d8 100644
--- a/pkg/sentry/kernel/sched/BUILD
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 36edf10f3..80e5e5da3 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index e5f297478..047b5214d 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -328,8 +328,14 @@ func (tg *ThreadGroup) createSession() error {
childTG.processGroup.incRefWithParent(pg)
childTG.processGroup.decRefWithParent(oldParentPG)
})
- tg.processGroup.decRefWithParent(oldParentPG)
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
} else {
// The current process group may be nil only in the case of an
// unparented thread group (i.e. the init process). This would
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
new file mode 100644
index 000000000..50b69d154
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -0,0 +1,22 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
new file mode 100644
index 000000000..06fd5ec88
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package signalfd provides an implementation of signal file descriptors.
+package signalfd
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalOperations represent a file with signalfd semantics.
+//
+// +stateify savable
+type SignalOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // target is the original task target.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+// New creates a new signalfd object with the supplied mask.
+func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // No task context? Not valid.
+ return nil, syserror.EINVAL
+ }
+ // name matches fs/signalfd.c:signalfd4.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{
+ target: t,
+ mask: mask,
+ }), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SignalOperations) Release() {}
+
+// Mask returns the signal mask.
+func (s *SignalOperations) Mask() linux.SignalSet {
+ s.mu.Lock()
+ mask := s.mask
+ s.mu.Unlock()
+ return mask
+}
+
+// SetMask sets the signal mask.
+func (s *SignalOperations) SetMask(mask linux.SignalSet) {
+ s.mu.Lock()
+ s.mask = mask
+ s.mu.Unlock()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := s.target.Sigtimedwait(s.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mask & waiter.EventIn
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ // Register for the signal set; ignore the passed events.
+ s.target.SignalRegister(entry, waiter.EventMask(s.Mask()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SignalOperations) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ s.target.SignalUnregister(entry)
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index e91f82bb3..c82ef5486 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
"gvisor.dev/gvisor/third_party/gvsync"
)
@@ -133,6 +134,13 @@ type Task struct {
// signalStack is exclusive to the task goroutine.
signalStack arch.SignalStack
+ // signalQueue is a set of registered waiters for signal-related events.
+ //
+ // signalQueue is protected by the signalMutex. Note that the task does
+ // not implement all queue methods, specifically the readiness checks.
+ // The task only broadcast a notification on signal delivery.
+ signalQueue waiter.Queue `state:"zerovalue"`
+
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
//
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 78ff14b20..ce3e6ef28 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) {
// disables the features we don't support anyway, is always set. This
// drastically simplifies this function.
//
-// - We don't implement AT_SECURE, because no_new_privs always being set means
-// that the conditions that require AT_SECURE never arise. (Compare Linux's
+// - We don't set AT_SECURE = 1, because no_new_privs always being set means
+// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's
// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
//
// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index e76c069b0..8b148db35 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
t.gosched.Timestamp = now
t.gosched.State = state
t.goschedSeq.EndWrite()
+
+ if state != TaskGoroutineRunningApp {
+ // Task is blocking/stopping.
+ t.k.decRunningTasks()
+ }
}
// Preconditions: The caller must be running on the task goroutine, and leaving
// a state indicated by a previous call to
// t.accountTaskGoroutineEnter(state).
func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ if state != TaskGoroutineRunningApp {
+ // Task is unblocking/continuing.
+ t.k.incRunningTasks()
+ }
+
now := t.k.CPUClockNow()
if t.gosched.State != state {
panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
@@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
}
// Notify implements ktime.TimerListener.Notify.
-func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
+func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
// Only increment cpuClock by 1 regardless of the number of expirations.
// This approximately compensates for cases where thread throttling or bad
// Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
@@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
tgs[i] = nil
}
ticker.tgs = tgs[:0]
+
+ // If nothing is running, we can disable the timer.
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks == 0 {
+ ticker.k.runningTasksMu.Lock()
+ defer ticker.k.runningTasksMu.Unlock()
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks != 0 {
+ // Raced with a 0 -> 1 transition.
+ return setting, false
+ }
+
+ // Stop the timer. We must cache the current setting so the
+ // kernel can access it without violating the lock order.
+ ticker.k.cpuClockTickerSetting = setting
+ ticker.k.cpuClockTickerDisabled = true
+ setting.Enabled = false
+ return setting, true
+ }
+
+ return setting, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 266959a07..39cd1340d 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -28,6 +28,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// SignalAction is an internal signal action.
@@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
//
// Preconditions: The signal mutex must be locked.
func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // Notify that the signal is queued.
+ t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig)))
+
// - Do not choose tasks that are blocking the signal.
if linux.SignalSetOf(sig)&t.signalMask != 0 {
return false
@@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
return t.deliverSignal(info, act)
}
+
+// SignalRegister registers a waiter for pending signals.
+func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventRegister(e, mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// SignalUnregister unregisters a waiter for pending signals.
+func (t *Task) SignalUnregister(e *waiter.Entry) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventUnregister(e)
+ t.tg.signalHandlers.mu.Unlock()
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0eef24bfb..72568d296 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -511,8 +511,9 @@ type itimerRealListener struct {
}
// Notify implements ktime.TimerListener.Notify.
-func (l *itimerRealListener) Notify(exp uint64) {
+func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index aa6c75d25..107394183 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
// A TimerListener receives expirations from a Timer.
type TimerListener interface {
// Notify is called when its associated Timer expires. exp is the number of
- // expirations.
+ // expirations. setting is the next timer Setting.
//
// Notify is called with the associated Timer's mutex locked, so Notify
// must not take any locks that precede Timer.mu in lock order.
//
+ // If Notify returns true, the timer will use the returned setting
+ // rather than the passed one.
+ //
// Preconditions: exp > 0.
- Notify(exp uint64)
+ Notify(exp uint64, setting Setting) (newSetting Setting, update bool)
// Destroy is called when the timer is destroyed.
Destroy()
@@ -533,7 +536,9 @@ func (t *Timer) Tick() {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
}
@@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, s
@@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
}
oldS, oldExp := t.setting.At(now)
if oldExp > 0 {
- t.listener.Notify(oldExp)
+ t.listener.Notify(oldExp, oldS)
+ // N.B. The returned Setting doesn't matter because we're about
+ // to overwrite.
}
if f != nil {
f()
@@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
newS, newExp := s.At(now)
t.setting = newS
if newExp > 0 {
- t.listener.Notify(newExp)
+ if newS, ok := t.listener.Notify(newExp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, oldS
@@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) {
}
// Notify implements ktime.TimerListener.Notify.
-func (c *ChannelNotifier) Notify(uint64) {
+func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) {
select {
case c.tchan <- struct{}{}:
default:
}
+
+ return Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy and will close the channel.
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 40025d62d..59649c770 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "limits",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index bc5b841fb..2d9251e92 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
return syserror.ENOEXEC
}
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit, always mapping at least RW. ignoring These
+ // pages are not included in the final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
if _, err := m.MMap(ctx, memmap.MMapOpts{
Length: uint64(anonSize),
Addr: anonAddr,
// Fixed without Unmap will fail the mmap if something is
// already at addr.
- Fixed: true,
- Private: true,
- // N.B. Linux uses vm_brk to map these pages, ignoring
- // the segment protections, instead always mapping RW.
- // These pages are not included in the final brk
- // region.
- Perms: usermem.ReadWrite,
+ Fixed: true,
+ Private: true,
+ Perms: prot,
MaxPerms: usermem.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
@@ -464,7 +468,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// base address big enough to fit all segments, so we first create a
// mapping for the total size just to find a region that is big enough.
//
- // It is safe to unmap it immediately with racing with another mapping
+ // It is safe to unmap it immediately without racing with another mapping
// because we are the only one in control of the MemoryManager.
//
// Note that the vaddr of the first PT_LOAD segment is ignored when
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index f6f1ae762..089d1635b 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -308,6 +308,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ // The conditions that require AT_SECURE = 1 never arise. See
+ // kernel.Task.updateCredsForExecLocked.
+ arch.AuxEntry{linux.AT_SECURE, 0},
arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
arch.AuxEntry{linux.AT_EXECFN, execfn},
arch.AuxEntry{linux.AT_RANDOM, random},
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 29c14ec56..9687e7e76 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "mappable_range",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 072745a08..b35c8c673 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "file_refcount_set",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 858f895f2..3fd904c67 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "evictable_range",
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index eeb634644..b6d008dbe 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index ad8b95744..31fa48ec5 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -54,6 +55,7 @@ go_test(
],
embed = [":kvm"],
tags = [
+ "manual",
"nogotsan",
"requires-kvm",
],
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 47957bb3b..72c7ec564 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) {
cpu: ^uint32(0),
}, nil
}
+
+// getEventMessage retrieves a message about the ptrace event that just happened.
+func (t *thread) getEventMessage() (uintptr, error) {
+ var msg uintptr
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETEVENTMSG,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(&msg)),
+ 0, 0)
+ if errno != 0 {
+ return msg, errno
+ }
+ return msg, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 6bf7cd097..9f0ecfbe4 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread {
// attach attaches to the thread.
func (t *thread) attach() {
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("unable to attach: %v", errno))
}
@@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- t.dumpAndPanic("wait failed: the process exited")
+ msg, err := t.getEventMessage()
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
@@ -416,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
for {
// Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -426,12 +427,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
break
} else {
// Some other signal caused a thread stop; ignore.
+ if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD {
+ log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig)
+ }
continue
}
}
// Complete the actual system call.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -522,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
for {
// Start running until the next system call.
if isSingleStepping(regs) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU_SINGLESTEP,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index f09b0b3d0..c075b5f91 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -53,7 +53,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// Enable cpuid-faulting; this may fail on older kernels or hardware,
// so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 3b95af617..ea090b686 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD
index 924d8a6d6..6769cd0a5 100644
--- a/pkg/sentry/platform/safecopy/BUILD
+++ b/pkg/sentry/platform/safecopy/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD
index fd6dc8e6e..884020f7b 100644
--- a/pkg/sentry/safemem/BUILD
+++ b/pkg/sentry/safemem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index eace3766d..c303435d5 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
// pkg/sentry/arch.
type sigaction struct {
handler uintptr
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 635042263..5812085fa 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -26,13 +26,16 @@ package epsocket
import (
"bytes"
+ "io"
"math"
+ "reflect"
"sync"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -52,6 +55,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -205,6 +209,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -224,7 +232,6 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
*waiter.Queue
@@ -409,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return int64(n), nil
}
-// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
-// based on the requested size.
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// t copies user memory bytes on demand based on the requested size.
type ioSequencePayload struct {
ctx context.Context
src usermem.IOSequence
}
-// Get implements tcpip.Payload.
-func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
- if size > i.Size() {
- size = i.Size()
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
@@ -428,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
return v, nil
}
-// Size implements tcpip.Payload.
-func (i *ioSequencePayload) Size() int {
- return int(i.src.NumBytes())
-}
-
// DropFirst drops the first n bytes from underlying src.
func (i *ioSequencePayload) DropFirst(n int) {
i.src = i.src.DropFirst(int(n))
@@ -466,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), nil
}
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
// Readiness returns a mask of ready events for socket s.
func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
@@ -774,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -790,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -825,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1162,7 +1292,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1170,7 +1300,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1188,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2057,7 +2194,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= int64(v.Size()) || dontWait) {
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2082,7 +2219,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2101,7 +2238,8 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
- if int(args[1].Int()) == syscall.SIOCGSTAMP {
+ switch args[1].Int() {
+ case syscall.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2113,6 +2251,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
AddressSpaceActive: true,
})
return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ // Add bytes removed from the endpoint but not yet sent to the caller.
+ v += len(s.readView)
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
}
return Ioctl(ctx, s.Endpoint, io, args)
@@ -2184,9 +2341,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
@@ -2421,7 +2578,8 @@ func (s *SocketOperations) State() uint32 {
return 0
}
- if !s.isPacketBased() {
+ switch {
+ case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -2450,9 +2608,26 @@ func (s *SocketOperations) State() uint32 {
// Internal or unknown state.
return 0
}
+ case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ // UDP socket.
+ switch udp.EndpointState(s.Endpoint.State()) {
+ case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ return linux.TCP_CLOSE
+ case udp.StateConnected:
+ return linux.TCP_ESTABLISHED
+ default:
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ // TODO(b/112063468): Export states for ICMP sockets.
+ case s.skType == linux.SOCK_RAW:
+ // TODO(b/112063468): Export states for raw sockets.
+ default:
+ // Unknown transport protocol, how did we make this socket?
+ log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem())
+ return 0
}
- // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
return 0
}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 421f93dc4..0a9dfa6c3 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 9e2e12799..445080aa4 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "port",
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..1867b3a5c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..7d7b42eba 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 3650fd6e1..5d57b75af 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 33a40b9c6..e76ee27d2 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/signalfd",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 2f77d587b..72c383537 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -19,4 +19,4 @@ const (
_LINUX_SYSNAME = "Linux"
_LINUX_RELEASE = "4.4"
_LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
-)
+) \ No newline at end of file
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 2e00a91ce..b9a8e3e21 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
if err != nil {
return err
}
- if dirPath {
- return syserror.ENOENT
- }
return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
if !fs.IsDir(d.Inode.StableAttr) {
@@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return err
}
- return d.Remove(t, root, name)
+ return d.Remove(t, root, name, dirPath)
})
}
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 3ab54271c..cd31e0649 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
}
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
// Pread64 implements linux syscall pread64(2).
func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 0104a94c0..fb6efd5d8 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -20,7 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
return 0, nil, syserror.EINTR
}
+
+// sharedSignalfd is shared between the two calls.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := copyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd?
+ if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
+ s.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create a new file.
+ file, err := signalfd.New(t, mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ // Set appropriate flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.SFD_NONBLOCK != 0,
+ })
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 17e3dde1f..9f705ebca 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
total int64
n int64
err error
- ch chan struct{}
- inW bool
- outW bool
+ inCh chan struct{}
+ outCh chan struct{}
)
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
@@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
break
}
- // Are we a registered waiter?
- if ch == nil {
- ch = make(chan struct{}, 1)
- }
- if !inW && !inFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- inFile.EventRegister(&w, EventMaskRead)
- defer inFile.EventUnregister(&w)
- inW = true // Registered.
- } else if !outW && !outFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- outFile.EventRegister(&w, EventMaskWrite)
- defer outFile.EventUnregister(&w)
- outW = true // Registered.
- }
-
- // Was anything registered? If no, everything is non-blocking.
- if !inW && !outW {
- break
- }
-
- if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) {
- // Something became ready, try again without blocking.
- continue
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(EventMaskRead) == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, EventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
}
-
- // Block until there's data.
- if err = t.Block(ch); err != nil {
- break
+ if outFile.Readiness(EventMaskWrite) == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
}
}
@@ -91,22 +88,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Get files.
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ if !inFile.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
outFile := t.GetFile(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
defer outFile.DecRef()
- inFile := t.GetFile(inFD)
- if inFile == nil {
+ if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
- // Verify that the outfile Append flag is not set. Note that fs.Splice
- // itself validates that the output file is writable.
+ // Verify that the outfile Append flag is not set.
if outFile.Flags().Append {
- return 0, nil, syserror.EBADF
+ return 0, nil, syserror.EINVAL
}
// Verify that we have a regular infile. This is a requirement; the
@@ -142,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Length: count,
SrcOffset: true,
SrcStart: offset,
- }, false)
+ }, outFile.Flags().NonBlocking)
// Copy out the new offset.
if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
@@ -152,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Send data using splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
- }, false)
+ }, outFile.Flags().NonBlocking)
}
// We can only pass a single file to handleIOError, so pick inFile
@@ -174,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful. Note that unlike in Linux, this
- // flag is applied consistently. We will have either fully blocking or
- // non-blocking behavior below, regardless of the underlying files
- // being spliced to. It's unclear if this is a bug or not yet.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -193,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
defer inFile.DecRef()
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Construct our options.
//
// Note that exactly one of the underlying buffers must be a pipe. We
@@ -240,17 +245,17 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
- default:
- return 0, nil, syserror.EINVAL
- }
- // We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
return 0, nil, syserror.EINVAL
}
// Splice data.
- n, err := doSplice(t, outFile, inFile, opts, nonBlocking)
+ n, err := doSplice(t, outFile, inFile, opts, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
@@ -268,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -294,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
+ // The operation is non-blocking if anything is non-blocking.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Splice data.
n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
Dup: true,
- }, nonBlocking)
+ }, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile)
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 4b3f043a2..b887fa9d7 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use
timer.Destroy()
- var remaining time.Duration
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- remaining = 0
- } else {
- remaining = dur - after.Sub(start)
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
if remaining < 0 {
remaining = time.Duration(0)
}
- }
- // Copy out remaining time.
- if err != nil && rem != usermem.Addr(0) {
- timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
- if err := copyTimespecOut(t, rem, &timeleft); err != nil {
- return err
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
}
- }
-
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- return nil
- }
- // If interrupted, arrange for a restart with the remaining duration.
- if err == syserror.ErrInterrupted {
+ // Arrange for a restart with the remaining duration.
t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
c: c,
duration: remaining,
rem: rem,
})
return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
-
- return err
}
// Nanosleep implements linux syscall Nanosleep(2).
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 8aa6a3017..beb43ba13 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index f4326706a..d6ef644d8 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -277,8 +277,3 @@ func TotalMemory(memSize, used uint64) uint64 {
}
return memSize
}
-
-// IncrementalMappedAccounting controls whether host mapped memory is accounted
-// incrementally during map translation. This may be modified during early
-// initialization, and is read-only afterward.
-var IncrementalMappedAccounting = false
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index a5b4206bb..cc5d25762 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "addr_range",
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 6eced660a..7b1f312b1 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -16,6 +16,7 @@
package usermem
import (
+ "bytes"
"errors"
"io"
"strconv"
@@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
- for i, c := range buf[done : done+n] {
- if c == 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
- }
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
}
+
done += n
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0f247bf77..eff4b44f6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..7eb2b2821 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index 00665c939..bdca80d37 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index c0f3c658d..329904457 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e70f4a79f..8a865d229 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index b149f9e02..bd3f9fd28 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index df37c7d5a..3fd9e3134 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tcpip",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 0d2637ee4..78df5a0b1 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 672f026b2..8ced960bb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 3301967fb..b4e8d6810 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "buffer",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index afcabd51d..096ad71ab 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker {
}
}
}
+
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
+// potentially additional ICMPv4 header fields.
+func ICMPv4(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
+ }
+
+ icmp := header.ICMPv4(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
+func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
+func ICMPv4Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
+// potentially additional ICMPv6 header fields.
+func ICMPv6(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
+ }
+
+ icmp := header.ICMPv6(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
+func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
+func ICMPv6Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 29b30be9c..0c5c20cea 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 76ef02f13..b558350c3 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "header",
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c52c0d851..0cac6c0a5 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
@@ -25,13 +26,29 @@ type ICMPv4 []byte
const (
// ICMPv4PayloadOffset defines the start of ICMP payload.
- ICMPv4PayloadOffset = 4
+ ICMPv4PayloadOffset = 8
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+
+ // icmpv4ChecksumOffset is the offset of the checksum field
+ // in an ICMPv4 message.
+ icmpv4ChecksumOffset = 2
+
+ // icmpv4MTUOffset is the offset of the MTU field
+ // in a ICMPv4FragmentationNeeded message.
+ icmpv4MTUOffset = 6
+
+ // icmpv4IdentOffset is the offset of the ident field
+ // in a ICMPv4EchoRequest/Reply message.
+ icmpv4IdentOffset = 4
+
+ // icmpv4SequenceOffset is the offset of the sequence field
+ // in a ICMPv4EchoRequest/Reply message.
+ icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792.
@@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) {
func (b ICMPv4) Payload() []byte {
return b[ICMPv4PayloadOffset:]
}
+
+// MTU retrieves the MTU field from an ICMPv4 message.
+func (b ICMPv4) MTU() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
+}
+
+// SetMTU sets the MTU field from an ICMPv4 message.
+func (b ICMPv4) SetMTU(mtu uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv4 message.
+func (b ICMPv4) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
+}
+
+// SetIdent sets the Ident field from an ICMPv4 message.
+func (b ICMPv4) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv4 message.
+func (b ICMPv4) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field from an ICMPv4 message.
+func (b ICMPv4) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
+}
+
+// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
+// and payload.
+func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := uint16(0)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 3cc57e234..1125a7d14 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
@@ -25,14 +26,18 @@ type ICMPv6 []byte
const (
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
- ICMPv6MinimumSize = 4
+ ICMPv6MinimumSize = 8
+
+ // ICMPv6PayloadOffset is the offset of the payload in an
+ // ICMP packet.
+ ICMPv6PayloadOffset = 8
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
- ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
ICMPv6NeighborAdvertSize = 32
@@ -42,11 +47,27 @@ const (
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
- ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
// packet-too-big packet.
- ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
+
+ // icmpv6ChecksumOffset is the offset of the checksum field
+ // in an ICMPv6 message.
+ icmpv6ChecksumOffset = 2
+
+ // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
+ // PacketTooBig message.
+ icmpv6MTUOffset = 4
+
+ // icmpv6IdentOffset is the offset of the ident field
+ // in a ICMPv6 Echo Request/Reply message.
+ icmpv6IdentOffset = 4
+
+ // icmpv6SequenceOffset is the offset of the sequence field
+ // in a ICMPv6 Echo Request/Reply message.
+ icmpv6SequenceOffset = 6
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
}
// SetChecksum calculates and sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) {
func (ICMPv6) SetDestinationPort(uint16) {
}
+// MTU retrieves the MTU field from an ICMPv6 message.
+func (b ICMPv6) MTU() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
+}
+
+// SetMTU sets the MTU field from an ICMPv6 message.
+func (b ICMPv6) SetMTU(mtu uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv6 message.
+func (b ICMPv6) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
+}
+
+// SetIdent sets the Ident field from an ICMPv6 message.
+func (b ICMPv6) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv6 message.
+func (b ICMPv6) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field from an ICMPv6 message.
+func (b ICMPv6) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
+}
+
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
- return b[ICMPv6MinimumSize:]
+ return b[ICMPv6PayloadOffset:]
+}
+
+// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header,
+// IPv6 src/dst addresses and the payload.
+func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := Checksum([]byte(src), 0)
+ xsum = Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = Checksum(upperLayerLength[:], xsum)
+ xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 17fc9c68e..554632a64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -21,16 +21,18 @@ import (
)
const (
- versIHL = 0
- tos = 1
- totalLen = 2
- id = 4
- flagsFO = 6
- ttl = 8
- protocol = 9
- checksum = 10
- srcAddr = 12
- dstAddr = 16
+ versIHL = 0
+ tos = 1
+ // IPv4TotalLenOffset is the offset of the total length field in the
+ // IPv4 header.
+ IPv4TotalLenOffset = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -103,6 +105,11 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+
+ // IPv4MinimumProcessableDatagramSize is the minimum size of an IP
+ // packet that every IPv4 capable host must be able to
+ // process/reassemble.
+ IPv4MinimumProcessableDatagramSize = 576
)
// Flags that may be set in an IPv4 packet.
@@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 {
// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
- return binary.BigEndian.Uint16(b[totalLen:])
+ return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
// Checksum returns the checksum field of the ipv4 header.
@@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
- binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+ binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
// SetChecksum sets the checksum field of the ipv4 header.
@@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
// packets are produced.
func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
b.SetTotalLength(totalLength)
- checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+ checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum)
b.SetChecksum(^checksum)
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 31be42ce0..9d3abc0e4 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -22,12 +22,14 @@ import (
)
const (
- versTCFL = 0
- payloadLen = 4
- nextHdr = 6
- hopLimit = 7
- v6SrcAddr = 8
- v6DstAddr = 24
+ versTCFL = 0
+ // IPv6PayloadLenOffset is the offset of the PayloadLength field in
+ // IPv6 header.
+ IPv6PayloadLenOffset = 4
+ nextHdr = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -74,6 +76,13 @@ const (
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
+ // IPv6AllNodesMulticastAddress is a link-local multicast group that
+ // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all nodes on a link.
+ //
+ // The address is ff02::1.
+ IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -94,7 +103,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
- return binary.BigEndian.Uint16(b[payloadLen:])
+ return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
// HopLimit returns the value of the "hop limit" field of the ipv6 header.
@@ -119,13 +128,13 @@ func (b IPv6) Payload() []byte {
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -148,18 +157,18 @@ func (b IPv6) SetTOS(t uint8, l uint32) {
// SetPayloadLength sets the "payload length" field of the ipv6 header.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
- binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+ binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -178,8 +187,8 @@ func (b IPv6) Encode(i *IPv6Fields) {
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
@@ -219,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool {
return addr[0] == 0xff
}
+// IsV6UnicastAddress determines if the provided address is a valid IPv6
+// unicast (and specified) address. That is, IsV6UnicastAddress returns
+// true if addr contains IPv6AddressSize bytes, is not the unspecified
+// address and is not a multicast address.
+func IsV6UnicastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ // Must not be unspecified
+ if addr == IPv6Any {
+ return false
+ }
+
+ // Return if not a multicast.
+ return addr[0] != 0xff
+}
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index c1f454805..74412c894 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -27,6 +27,11 @@ const (
udpChecksum = 6
)
+const (
+ // UDPMaximumPacketSize is the largest possible UDP packet.
+ UDPMaximumPacketSize = 0xffff
+)
+
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c40744b8e..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -44,14 +44,12 @@ type Endpoint struct {
}
// New creates a new channel endpoint.
-func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
+func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
+ return &Endpoint{
C: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
-
- return stack.RegisterLinkEndpoint(e), e
}
// Drain removes all outbound packets from the channel and counts them.
@@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d786d8fdf..8fa9e3984 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -8,8 +9,8 @@ go_library(
"endpoint.go",
"endpoint_unsafe.go",
"mmap.go",
- "mmap_amd64.go",
- "mmap_amd64_unsafe.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 77f988b9f..7636418b1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,19 @@ const (
PacketMMap
)
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ }
+}
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +128,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -164,8 +181,9 @@ type Options struct {
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
-func New(opts *Options) (tcpip.LinkEndpointID, error) {
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being using and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -190,7 +208,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
if len(opts.FDs) == 0 {
- return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
e := &endpoint{
@@ -207,12 +225,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
for i := 0; i < len(e.fds); i++ {
fd := e.fds[i]
if err := syscall.SetNonblock(fd, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
+ return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
isSocket, err := isSocketFD(fd)
if err != nil {
- return 0, err
+ return nil, err
}
if isSocket {
if opts.GSOMaxSize != 0 {
@@ -222,12 +240,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
}
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
@@ -290,7 +308,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +342,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
@@ -435,14 +463,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf
}
// NewInjectable creates a new fd-based InjectableEndpoint.
-func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint {
syscall.SetNonblock(fd, true)
- e := &InjectableEndpoint{endpoint: endpoint{
+ return &InjectableEndpoint{endpoint: endpoint{
fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
-
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index e305252d6..04406bc9a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context {
}
opt.FDs = []int{fds[1]}
- epID, err := New(opt)
+ ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
- ep := stack.FindLinkEndpoint(epID).(*endpoint)
c := &context{
t: t,
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 2dca173c2..8bfeb97e4 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,12 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-// Stubbed out version for non-linux/non-amd64 platforms.
+import (
+ "encoding/binary"
+ "syscall"
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- return nil, nil
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_amd64_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmaped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index 029f86a18..000000000
--- a/pkg/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
new file mode 100644
index 000000000..67be52d67
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for non-linux/non-amd64/non-arm64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1cc..3894185ae 100644
--- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ab6a53988..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -32,8 +32,8 @@ type endpoint struct {
// New creates a new loopback endpoint. This link-layer endpoint just turns
// outbound packets into inbound packets.
-func New() tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{})
+func New() stack.LinkEndpoint {
+ return &endpoint{}
}
// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index ea12ef1ac..1bab380b0 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index a577a3d52..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
-func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) {
- e := &InjectableEndpoint{
+func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
+ return &InjectableEndpoint{
routes: routes,
}
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 174b9330f..3086fec00 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc
if err != nil {
t.Fatal("Failed to create socket pair:", err)
}
- _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
+ underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- _, endpoint := NewInjectableEndpoint(routes)
+ endpoint := NewInjectableEndpoint(routes)
return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6e3a7a9d7..2e8bc772a 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -6,8 +6,9 @@ go_library(
name = "rawfile",
srcs = [
"blockingpoll_amd64.s",
- "blockingpoll_amd64_unsafe.go",
- "blockingpoll_unsafe.go",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 84dc0e918..621ab8d29 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,!amd64
+// +build linux,!amd64,!arm64
package rawfile
@@ -22,7 +22,7 @@ import (
)
// BlockingPoll is just a stub function that forwards to the ppoll() system call
-// on non-amd64 platforms.
+// on non-amd64 and non-arm64 platforms.
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 47039a446..dda3b10a6 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,6 +25,12 @@ import (
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
//go:noescape
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index f2998aa98..0a5ea3dc4 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 94725cb11..330ed5e94 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 160a8f864..de1ce043d 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 834ea5c40..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -94,7 +94,7 @@ type endpoint struct {
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) {
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
e := &endpoint{
mtu: mtu,
bufferSize: bufferSize,
@@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc
}
if err := e.tx.init(bufferSize, &tx); err != nil {
- return 0, err
+ return nil, err
}
if err := e.rx.init(bufferSize, &rx); err != nil {
e.tx.cleanup()
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
// Close frees all resources associated with the endpoint.
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 98036f367..0e9ba0846 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
if err != nil {
t.Fatalf("New failed: %v", err)
}
- c.ep = stack.FindLinkEndpoint(id).(*endpoint)
+ c.ep = ep.(*endpoint)
c.ep.Attach(c)
return c
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 36c8c46fc..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -58,10 +58,10 @@ type endpoint struct {
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
-func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
- })
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ return &endpoint{
+ lower: lower,
+ }
}
func zoneOffset() (int32, error) {
@@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal too snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
if err := writePCAPHeader(file, snapLen); err != nil {
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
+ return &endpoint{
+ lower: lower,
file: file,
maxPCAPLen: snapLen,
- }), nil
+ }, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 2597d4b3e..0746dc8ec 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 3b6ac2ff7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -40,11 +40,10 @@ type Endpoint struct {
// New creates a new waitable link-layer endpoint. It wraps around another
// endpoint and allows the caller to block new write/dispatch calls and wait for
// the inflight ones to finish before returning.
-func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
- lower: stack.FindLinkEndpoint(lower),
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
}
- return stack.RegisterLinkEndpoint(e), e
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
@@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 56e18ecb0..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Write and check that it goes through.
wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
@@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) {
func TestWaitDispatch(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Check that attach happens.
wep.Attach(ep)
@@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) {
hdrLen: hdrLen,
linkAddr: linkAddr,
}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
if v := wep.MTU(); v != mtu {
t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index f36f49453..9d16ff8c9 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d95d44f56..df0d3a8c0 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ea7296e6a..26cf1c528 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -112,11 +109,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
- fallthrough // also fill the cache from requests
case header.ARPReply:
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
@@ -204,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 4c4b54469..88b57ec03 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,14 +44,19 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 118bfc763..c5c7aad86 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "reassembler_list",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6bbfcd97f..a9741622e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
@@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -319,7 +328,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
icmp.SetType(header.ICMPv4DstUnreachable)
icmp.SetCode(c.code)
- copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv4 header.
ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
@@ -539,7 +549,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
defer ep.Close()
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
}
@@ -559,10 +569,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
icmp.SetType(c.typ)
icmp.SetCode(c.code)
- copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
ip.Encode(&header.IPv6Fields{
PayloadLength: 100,
NextHeader: 10,
@@ -574,7 +585,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
frag.Encode(&header.IPv6FragmentFields{
NextHeader: 10,
FragmentOffset: *c.fragmentOffset,
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index be84fa63d..58e537aad 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 497164cbb..a25756443 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,8 +15,6 @@
package ipv4
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
e.handleControl(stack.ControlPortUnreachable, 0, vv)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:]))
+ mtu := uint32(h.MTU())
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..b7b07a6c1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -53,6 +50,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
}
// Number returns the ipv4 protocol number.
@@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1b5a55bea..b6641ccc3 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,14 +33,17 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
- id, _ := channel.New(256, defaultMTU, "")
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ ep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -184,15 +187,12 @@ type errorChannel struct {
// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
// will return successive errors from packetCollectorErrors until the list is
// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
- _, e := channel.New(size, mtu, linkAddr)
- ec := errorChannel{
- Endpoint: e,
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
Ch: make(chan packetInfo, size),
packetCollectorErrors: packetCollectorErrors,
}
-
- return stack.RegisterLinkEndpoint(e), &ec
}
// packetInfo holds all the information about an outbound packet.
@@ -241,10 +241,11 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
- _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- linkEPId := stack.RegisterLinkEndpoint(linkEP)
- s.CreateNIC(1, linkEPId)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -266,7 +267,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
}
return context{
Route: r,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index fae7f4507..f06622a8b 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -23,7 +24,11 @@ go_library(
go_test(
name = "ipv6_test",
size = "small",
- srcs = ["icmp_test.go"],
+ srcs = [
+ "icmp_test.go",
+ "ipv6_test.go",
+ "ndp_test.go",
+ ],
embed = [":ipv6"],
deps = [
"//pkg/tcpip",
@@ -33,6 +38,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5e6a59e91..b4d0295bf 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,14 +15,21 @@
package ipv6
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+const (
+ // ndpHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // its IP hop limit field.
+ ndpHopLimit = 255
+)
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -73,6 +80,21 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
}
h := header.ICMPv6(v)
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255.
+ switch h.Type() {
+ case header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6RouterSolicit,
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6RedirectMsg:
+ if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ received.Invalid.Increment()
+ return
+ }
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv6PacketTooBig:
@@ -82,7 +104,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ mtu := h.MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
case header.ICMPv6DstUnreachable:
@@ -100,13 +122,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
@@ -132,7 +152,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
sent.Dropped.Increment()
@@ -146,7 +166,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -164,7 +184,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
sent.Dropped.Increment()
return
@@ -235,7 +255,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(hdr.UsedLength())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -274,24 +294,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return "", false
}
-
-func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := header.Checksum([]byte(src), 0)
- xsum = header.Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = header.Checksum(upperLayerLength[:], xsum)
- xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^header.Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
-}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d0dc72506..01f5a17ec 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -81,10 +81,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
@@ -153,7 +155,7 @@ func TestICMPCounts(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
handleIPv6Payload(hdr)
}
@@ -206,41 +208,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
- _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP0
- wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
- id0 := stack.RegisterLinkEndpoint(wrappedEP0)
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
- id0 = sniffer.New(id0)
+ wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, id0); err != nil {
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
- t.Fatalf("AddAddress sn lladdr0: %v", err)
- }
- _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP1
- wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
- id1 := stack.RegisterLinkEndpoint(wrappedEP1)
- if err := c.s1.CreateNIC(1, id1); err != nil {
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
- t.Fatalf("AddAddress sn lladdr1: %v", err)
- }
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
if err != nil {
@@ -321,7 +320,7 @@ func TestLinkResolution(t *testing.T) {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..7de6a4546 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,9 +14,9 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
@@ -28,9 +28,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -160,14 +157,6 @@ func (*endpoint) Close() {}
type protocol struct{}
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
-}
-
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
@@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..78c674c2c
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,266 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveICMP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as we haven't added
+ // those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited
+ // node address of addr2/addr3 now that we have added
+ // added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(1, addr2); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ if err := s.RemoveAddress(1, addr3); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as both of them got
+ // removed.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
new file mode 100644
index 000000000..e30791fe3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -0,0 +1,181 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+// setupStackAndEndpoint creates a stack with a single NIC with a link-local
+// address llladdr and an IPv6 endpoint to a remote with link-local address
+// rlladdr
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep
+}
+
+// TestHopLimitValidation is a test that makes sure that NDP packets are only
+// received if their IP header's hop limit is set to 255.
+func TestHopLimitValidation(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep, r
+ }
+
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ }},
+ {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ }},
+ {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ }},
+ {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ }},
+ {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ }},
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ // Should not have received any ICMPv6 packets with
+ // type = typ.typ.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with an invalid hop limit
+ // value.
+ handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ // Rx count of NDP packet of type typ.typ should not
+ // have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with a valid hop limit value.
+ handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+
+ // Rx count of NDP packet of type typ.typ should have
+ // increased.
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 989058413..11efb4e44 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..30cea8996 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -19,6 +19,7 @@ import (
"math"
"math/rand"
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,6 +28,10 @@ const (
// FirstEphemeral is the first ephemeral port.
FirstEphemeral = 16000
+ // numEphemeralPorts it the mnumber of available ephemeral ports to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
anyIPAddress tcpip.Address = ""
)
@@ -40,6 +45,13 @@ type portDescriptor struct {
type PortManager struct {
mu sync.RWMutex
allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
}
type portNode struct {
@@ -47,43 +59,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all portNodes. If binding to a specific device, check
+// against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to binding all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
+ // Check that this is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
return true
}
@@ -97,11 +142,40 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - FirstEphemeral + 1)
- offset := uint16(rand.Int31n(int32(count)))
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
- for i := uint16(0); i < count; i++ {
- port = FirstEphemeral + (offset+i)%count
+// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// iterates over all ephemeral ports, allowing the caller to decide whether a
+// given port is suitable for its needs and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+ if err == nil {
+ s.incPortHint()
+ }
+ return p, err
+
+}
+
+// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
+// and iterates over the number of ports specified by count and allows the
+// caller to decide whether a given port is suitable for its needs, and stopping
+// when a port is found or an error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..19f4833fc 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,6 +35,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
@@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}
+
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index e2021cd15..2239c1e66 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -126,7 +126,10 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -138,11 +141,11 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
+ linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
+ if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1716be285..bca73cbb1 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -111,7 +111,10 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -128,7 +131,7 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{
+ linkEP, err := fdbased.New(&fdbased.Options{
FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
@@ -137,7 +140,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, linkID); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 9986b4be3..baf88bfab 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,11 +1,28 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
go_library(
name = "stack",
srcs = [
+ "icmp_rate_limit.go",
"linkaddrcache.go",
+ "linkaddrentry_list.go",
"nic.go",
"registration.go",
"route.go",
@@ -19,6 +36,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -28,6 +46,7 @@ go_library(
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
+ "@org_golang_x_time//rate:go_default_library",
],
)
@@ -36,6 +55,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -46,6 +66,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
@@ -60,3 +83,11 @@ go_test(
"//pkg/tcpip",
],
)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "linkaddrentry_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
new file mode 100644
index 000000000..3a20839da
--- /dev/null
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "golang.org/x/time/rate"
+)
+
+const (
+ // icmpLimit is the default maximum number of ICMP messages permitted by this
+ // rate limiter.
+ icmpLimit = 1000
+
+ // icmpBurst is the default number of ICMP messages that can be sent in a single
+ // burst.
+ icmpBurst = 50
+)
+
+// ICMPRateLimiter is a global rate limiter that controls the generation of
+// ICMP messages generated by the stack.
+type ICMPRateLimiter struct {
+ *rate.Limiter
+}
+
+// NewICMPRateLimiter returns a global rate limiter for controlling the rate
+// at which ICMP messages are generated by the stack.
+func NewICMPRateLimiter() *ICMPRateLimiter {
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 77bb0ccb9..267df60d1 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@ const (
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@ func (s entryState) String() string {
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,64 +80,46 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is nil iff
+ // s is incomplete and resolution is not yet in progress.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
}
- e.s = ns
-}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
+ e.s = ns
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
@@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry.changeState(ready)
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
- *entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
}
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
- }
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
- return "", e.done, tcpip.ErrWouldBlock
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 924f4d240..9946b8fe8 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@ type testaddr struct {
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) {
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) {
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) {
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) {
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) {
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) {
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4ef85bdfb..f6106f762 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,15 +34,13 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- subnets []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -102,6 +99,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
}
}
+// enable enables the NIC. enable will attach the link to its LinkEndpoint and
+// join the IPv6 All-Nodes Multicast address (ff02::1).
+func (n *NIC) enable() *tcpip.Error {
+ n.attachLinkEndpoint()
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ }
+
+ return nil
+}
+
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
func (n *NIC) attachLinkEndpoint() {
@@ -129,37 +145,6 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var r *referencedNetworkEndpoint
-
- // Check for a primary endpoint.
- if list, ok := n.primary[protocol]; ok {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
- r = ref
- break
- }
- }
-
- }
-
- if r == nil {
- return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
- }
-
- addressWithPrefix := tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
- r.decRef()
-
- return addressWithPrefix, nil
-}
-
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -178,7 +163,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -197,22 +182,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
// protocol and address. If none exists a temporary one may be created if
-// requested.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint {
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.getKind() {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
}
- // The address was not found, create a temporary one if requested by the
- // caller or if the address is found in the NIC's subnets.
- createTempEP := allowTemp
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
if !createTempEP {
- for _, sn := range n.subnets {
+ for _, sn := range n.addressRanges {
+ // Skip the subnet address.
+ if address == sn.ID() {
+ continue
+ }
+ // For now just skip the broadcast address, until we support it.
+ // FIXME(b/137608825): Add support for sending/receiving directed
+ // (subnet) broadcast.
+ if address == sn.Broadcast() {
+ continue
+ }
if sn.Contains(address) {
createTempEP = true
break
@@ -230,34 +237,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
+ // Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.mu.Unlock()
return nil
}
-
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, true)
-
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ }, peb, temporary)
n.mu.Unlock()
return ref
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.endpoints[id]; ok {
+ switch ref.getKind() {
+ case permanent:
+ // The NIC already have a permanent endpoint with that address.
+ return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.setKind(permanent)
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
+ }
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
+
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP address before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -268,22 +311,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
-
- id := *ep.ID()
- if ref, ok := n.endpoints[id]; ok {
- if !replace {
- return nil, tcpip.ErrDuplicateAddress
- }
-
- n.removeEndpointLocked(ref)
- }
-
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
@@ -293,6 +326,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
+ // If we are adding an IPv6 unicast address, join the solicited-node
+ // multicast address.
+ if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
+ if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
+ return nil, err
+ }
+ }
+
n.endpoints[id] = ref
l, ok := n.primary[protocolAddress.Protocol]
@@ -316,18 +358,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
}
-// Addresses returns the addresses associated with this NIC.
-func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
+
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -339,45 +389,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
return addrs
}
-// AddSubnet adds a new subnet to n, so that it starts accepting packets
-// targeted at the given address and network protocol.
-func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.primary {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
+// AddAddressRange adds a range of addresses to n, so that it starts accepting
+// packets targeted at the given addresses and network protocol. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
- n.subnets = append(n.subnets, subnet)
+ n.addressRanges = append(n.addressRanges, subnet)
n.mu.Unlock()
}
-// RemoveSubnet removes the given subnet from n.
-func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+// RemoveAddressRange removes the given address range from n.
+func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
- tmp := n.subnets[:0]
- for _, sub := range n.subnets {
+ tmp := n.addressRanges[:0]
+ for _, sub := range n.addressRanges {
if sub != subnet {
tmp = append(tmp, sub)
}
}
- n.subnets = tmp
+ n.addressRanges = tmp
n.mu.Unlock()
}
-// ContainsSubnet reports whether this NIC contains the given subnet.
-func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
- for _, s := range n.Subnets() {
- if s == subnet {
- return true
- }
- }
- return false
-}
-
// Subnets returns the Subnets associated with this NIC.
-func (n *NIC) Subnets() []tcpip.Subnet {
+func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
for nid := range n.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
@@ -387,19 +458,22 @@ func (n *NIC) Subnets() []tcpip.Subnet {
}
sns = append(sns, sn)
}
- return append(sns, n.subnets...)
+ return append(sns, n.addressRanges...)
}
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.getKind() == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -418,15 +492,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
- r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok || r.getKind() != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
+ r.setKind(permanentExpired)
+ if !r.decRefLocked() {
+ // The endpoint still has references to it.
+ return nil
+ }
- r.decRefLocked()
+ // At this point the endpoint is deleted.
+
+ // If we are removing an IPv6 unicast address, leave the solicited-node
+ // multicast address.
+ if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ snmc := header.SolicitedNodeAddr(addr)
+ if err := n.leaveGroupLocked(snmc); err != nil {
+ return err
+ }
+ }
return nil
}
@@ -435,7 +522,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -444,6 +531,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
n.mu.Lock()
defer n.mu.Unlock()
+ return n.joinGroupLocked(protocol, addr)
+}
+
+// joinGroupLocked adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count. n MUST be locked before
+// joinGroupLocked is called.
+func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -451,13 +545,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint, false); err != nil {
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -471,6 +565,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
+ return n.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked decrements the count for the given multicast address, and
+// when it reaches zero removes the endpoint for this address. n MUST be locked
+// before leaveGroupLocked is called.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
switch joins {
@@ -479,7 +580,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -487,6 +588,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -514,6 +622,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
+ n.stack.AddLinkAddress(n.id, src, remote)
+
// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
@@ -521,11 +631,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
// n.endpoints is mutex protected so acquire lock.
n.mu.RLock()
for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
}
}
n.mu.RUnlock()
@@ -533,10 +640,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
@@ -559,8 +663,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -599,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -615,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -631,7 +731,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
@@ -659,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
@@ -672,9 +769,38 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Stack returns the instance of the Stack that owns this NIC.
+func (n *NIC) Stack() *Stack {
+ return n.stack
+}
+
+type networkEndpointKind int32
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endoint is a permanent endoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no more route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -683,11 +809,34 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ // networkEndpointKind must only be accessed using {get,set}Kind().
+ kind networkEndpointKind
+}
+
+func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
+ return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
+}
+
+func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
+ atomic.StoreInt32((*int32)(&r.kind), int32(kind))
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.getKind() != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.getKind() != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
@@ -699,11 +848,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
+// locked. Returns true if the endpoint was removed.
+func (r *referencedNetworkEndpoint) decRefLocked() bool {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
+ return true
}
+
+ return false
}
// incRef increments the ref count. It must only be called when the caller is
@@ -728,3 +880,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool {
}
}
}
+
+// stack returns the Stack instance that owns the underlying endpoint.
+func (r *referencedNetworkEndpoint) stack() *Stack {
+ return r.nic.stack
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2037eef9f..80101d4bb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -109,7 +107,7 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -297,6 +295,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -359,14 +366,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -374,60 +373,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-
- linkEPMu sync.RWMutex
- nextLinkEndpointID tcpip.LinkEndpointID = 1
- linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
-// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
-// ID that can be used to refer to it.
-func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
- linkEPMu.Lock()
- defer linkEPMu.Unlock()
-
- v := nextLinkEndpointID
- nextLinkEndpointID++
-
- linkEndpoints[v] = linkEP
-
- return v
-}
-
-// FindLinkEndpoint finds the link endpoint associated with the given ID.
-func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
- linkEPMu.RLock()
- defer linkEPMu.RUnlock()
-
- return linkEndpoints[id]
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..5c8b7977a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -209,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route {
}
return l
}
+
+// Stack returns the instance of the Stack that owns this route.
+func (r *Route) Stack() *Stack {
+ return r.ref.stack()
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index d69162ba1..90c2cf1be 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,17 +17,15 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
+ "encoding/binary"
"sync"
"time"
+ "golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -350,6 +348,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -358,10 +359,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -389,10 +386,26 @@ type Stack struct {
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
resumableEndpoints []ResumableEndpoint
+
+ // icmpRateLimiter is a global rate limiter for all ICMP messages generated
+ // by the stack.
+ icmpRateLimiter *ICMPRateLimiter
+
+ // portSeed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections
+ //
+ // TODO(gvisor.dev/issues/940): S/R this field.
+ portSeed uint32
}
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -406,8 +419,9 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated endpoints raw endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
// New allocates a new networking stack with only the requested networking and
@@ -417,7 +431,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -433,16 +447,12 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -450,18 +460,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -596,7 +602,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
@@ -614,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
- ep := FindLinkEndpoint(linkEP)
- if ep == nil {
- return tcpip.ErrBadLinkEndpoint
- }
-
+func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -632,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
s.nics[id] = n
if enabled {
- n.attachLinkEndpoint()
+ return n.enable()
}
return nil
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
-func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true, false)
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, false)
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, false)
}
// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, true)
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false, false)
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false, false)
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -679,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- nic.attachLinkEndpoint()
-
- return nil
+ return nic.enable()
}
// CheckNIC checks if a NIC is usable.
@@ -696,14 +695,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
}
// NICSubnets returns a map of NICIDs to their associated subnets.
-func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
nics := map[tcpip.NICID][]tcpip.Subnet{}
for id, nic := range s.nics {
- nics[id] = append(nics[id], nic.Subnets()...)
+ nics[id] = append(nics[id], nic.AddressRanges()...)
}
return nics
}
@@ -739,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.Addresses(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -804,71 +803,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return nic.AddAddress(protocolAddress, peb)
}
-// AddSubnet adds a subnet range to the specified NIC.
-func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+// AddAddressRange adds a range of addresses to the specified NIC. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.AddSubnet(protocol, subnet)
+ nic.AddAddressRange(protocol, subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// RemoveSubnet removes the subnet range from the specified NIC.
-func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+// RemoveAddressRange removes the range of addresses from the specified NIC.
+func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.RemoveSubnet(subnet)
+ nic.RemoveAddressRange(subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// ContainsSubnet reports whether the specified NIC contains the specified
-// subnet.
-func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.ContainsSubnet(subnet), nil
+ return nic.RemoveAddress(addr)
}
- return false, tcpip.ErrUnknownNICID
+ return tcpip.ErrUnknownNICID
}
-// RemoveAddress removes an existing network-layer address from the specified
-// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
}
-
- return tcpip.ErrUnknownNICID
+ return nics
}
-// GetMainNICAddress returns the first primary address (and the subnet that
-// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
-// address if no primary addresses exist. Returns an error if the NIC doesn't
-// exist or has no endpoints.
+// GetMainNICAddress returns the first primary address and prefix for the given
+// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
+// value if the NIC doesn't have a primary address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.getMainNICAddress(protocol)
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ for _, a := range nic.PrimaryAddresses() {
+ if a.Protocol == protocol {
+ return a.AddressWithPrefix, nil
+ }
+ }
+ return tcpip.AddressWithPrefix{}, nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1035,73 +1042,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1215,3 +1176,49 @@ func (s *Stack) IPTables() iptables.IPTables {
func (s *Stack) SetIPTables(ipt iptables.IPTables) {
s.tables = ipt
}
+
+// ICMPLimit returns the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) ICMPLimit() rate.Limit {
+ return s.icmpRateLimiter.Limit()
+}
+
+// SetICMPLimit sets the maximum number of ICMP messages that be sent
+// in one second.
+func (s *Stack) SetICMPLimit(newLimit rate.Limit) {
+ s.icmpRateLimiter.SetLimit(newLimit)
+}
+
+// ICMPBurst returns the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) ICMPBurst() int {
+ return s.icmpRateLimiter.Burst()
+}
+
+// SetICMPBurst sets the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) SetICMPBurst(burst int) {
+ s.icmpRateLimiter.SetBurst(burst)
+}
+
+// AllowICMPMessage returns true if we the rate limiter allows at least one
+// ICMP message to be sent at this instant.
+func (s *Stack) AllowICMPMessage() bool {
+ return s.icmpRateLimiter.Allow()
+}
+
+// PortSeed returns a 32 bit value that can be used as a seed value for port
+// picking.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) PortSeed() uint32 {
+ return s.portSeed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 137c6183e..d2dede8a9 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct {
prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- linkEP stack.LinkEndpoint
+ ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
@@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.linkEP.Capabilities()
+ return f.ep.Capabilities()
}
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return nil
}
- return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -181,18 +181,22 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
- linkEP: linkEP,
+ ep: ep,
}, nil
}
@@ -218,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -241,7 +251,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -251,7 +261,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -261,7 +271,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -270,7 +280,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -280,7 +290,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -289,16 +299,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -306,9 +375,11 @@ func TestNetworkSend(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and one
// address: 1. The route table sends all packets through the only
// existing nic.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -325,20 +396,19 @@ func TestNetworkSend(t *testing.T) {
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", ep, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -350,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -382,18 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", ep1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", ep2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -424,10 +486,12 @@ func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -439,8 +503,8 @@ func TestRoutes(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -498,58 +562,71 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
-
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
@@ -558,58 +635,239 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
+
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+ t.Helper()
+ info, ok := s.NICInfo()[nicid]
+ if !ok {
+ t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ }
+ if len(addr) == 0 {
+ // No address given, verify that there is no address assigned to the NIC.
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ }
+ }
+ return
+ }
+ // Address given, verify the address is assigned to the NIC and no other
+ // address is.
+ found := false
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber {
+ if a.AddressWithPrefix.Address == addr {
+ found = true
+ } else {
+ t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
+ }
+ }
+ }
+ if !found {
+ t.Errorf("verify address: couldn't find %s on the NIC", addr)
+ }
+}
+
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 8. Remove the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -627,22 +885,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -655,25 +906,24 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
@@ -687,7 +937,7 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
@@ -697,23 +947,92 @@ func TestAddressSpoofing(t *testing.T) {
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != localAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
+ // Sending a packet using the route works.
+ testSend(t, r, ep, nil)
+}
+
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
+ dstAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
@@ -781,10 +1100,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
{"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
} {
t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -835,12 +1156,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-// Set the subnet, then check that packet is delivered.
-func TestSubnetAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Add a range of addresses, then check that a packet is delivered.
+func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -856,29 +1179,59 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
+ t.Helper()
+
+ // Loop over all addresses and check them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ addr := tcpip.Address(addrBytes)
+ wantNicID := nicID
+ // The subnet and broadcast addresses are skipped.
+ if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
+ wantNicID = 0
+ }
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
+ }
+ addrBytes[0]++
+ }
+
+ // Trying the next address should always fail since it is outside the range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
}
}
-// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
+// Set a range of addresses, then remove it again, and check at each step that
+// CheckLocalAddress returns the correct NIC for each address or zero if not
+// existent.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -891,39 +1244,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
}
subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
-
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- // Loop over all subnet addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
- addr := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
- }
- addr[0]++
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
+
+ if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- // Trying the next address should fail since it is outside the subnet range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
+ testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
+
+ if err := s.RemoveAddressRange(nicID, subnet); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
+
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
}
-// Set destination outside the subnet, then check it doesn't get delivered.
-func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Set a range of addresses, then send a packet to a destination outside the
+// range and then check it doesn't get delivered.
+func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -939,23 +1287,23 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -994,44 +1342,53 @@ func TestNetworkOptions(t *testing.T) {
}
}
-func TestSubnetAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
+ ranges, ok := s.NICAddressRanges()[id]
+ if !ok {
+ return false
+ }
+ for _, r := range ranges {
+ if r == addrRange {
+ return true
+ }
+ }
+ return false
+}
+
+func TestAddresRangeAddRemove(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subnet, err := tcpip.NewSubnet(addr, mask)
+ addrRange, err := tcpip.NewSubnet(addr, mask)
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if !contained {
- t.Fatal("got s.ContainsSubnet(...) = false, want = true")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatal("RemoveSubnet failed:", err)
+ if err := s.RemoveAddressRange(1, addrRange); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
}
@@ -1042,9 +1399,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
// Insert <canBe> primary and <never> never-primary addresses.
@@ -1082,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
}
} else {
- // At least one primary address was added, expect a valid
- // address and prefixLen.
- gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
- t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
}
}
})
@@ -1107,9 +1466,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1134,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get an error after removal.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
})
}
@@ -1161,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address {
}
func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
}
sort.Slice(gotAddresses, func(i, j int) bool {
@@ -1182,9 +1551,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1201,15 +1572,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1233,15 +1606,17 @@ func TestAddProtocolAddress(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1262,15 +1637,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1297,15 +1674,17 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
@@ -1321,7 +1700,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1332,10 +1711,12 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
- want := uint64(linkEP1.Drain())
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := uint64(ep1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
+ t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
@@ -1346,19 +1727,21 @@ func TestNICStats(t *testing.T) {
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
s.SetForwarding(true)
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -1377,10 +1760,10 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
select {
- case <-linkEP2.C:
+ case <-ep2.C:
default:
t.Fatal("Packet not forwarded")
}
@@ -1394,9 +1777,3 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..8c768c299 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint returns true if it succeeds. It fails and returns
+// false if ep already has an element with the same key.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list might be empty already.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
+ destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
@@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
return ep
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and also linkEpNames
+// named NICs.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestReuseBindToDevice injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will received the inject packets.
+ endpoints []endpointSockopts
+ // wantedDistribution is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The deviation is less than 10%.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5335897f5..842a16277 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -122,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -163,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -251,7 +257,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -277,10 +283,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
func TestTransportReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -340,9 +353,12 @@ func TestTransportReceive(t *testing.T) {
}
func TestTransportControlReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -408,9 +424,12 @@ func TestTransportControlReceive(t *testing.T) {
}
func TestTransportSend(t *testing.T) {
- id, _ := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -452,7 +471,10 @@ func TestTransportSend(t *testing.T) {
}
func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -493,20 +515,23 @@ func TestTransportOptions(t *testing.T) {
}
func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
- id1 := loopback.New()
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := loopback.New()
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatalf("CreateNIC #1 failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress #1 failed: %v", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatalf("CreateNIC #2 failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -545,7 +570,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.Inject(fakeNetNumber, req.ToVectorisedView())
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -559,7 +584,7 @@ func TestTransportForwarding(t *testing.T) {
var p channel.PacketInfo
select {
- case p = <-linkEP2.C:
+ case p = <-ep2.C:
default:
t.Fatal("Response packet not forwarded")
}
@@ -571,9 +596,3 @@ func TestTransportForwarding(t *testing.T) {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 043dd549b..faaa4a4e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -219,6 +219,15 @@ func (s *Subnet) Mask() AddressMask {
return s.mask
}
+// Broadcast returns the subnet's broadcast address.
+func (s *Subnet) Broadcast() Address {
+ addr := []byte(s.address)
+ for i := range addr {
+ addr[i] |= ^s.mask[i]
+ }
+ return Address(addr)
+}
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -252,31 +261,34 @@ type FullAddress struct {
Port uint16
}
-// Payload provides an interface around data that is being sent to an endpoint.
-// This allows the endpoint to request the amount of data it needs based on
-// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
-type Payload interface {
- // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
- Get(size int) ([]byte, *Error)
-
- // Size returns the payload size.
- Size() int
+// Payloader is an interface that provides data.
+//
+// This interface allows the endpoint to request the amount of data it needs
+// based on internal buffers without exposing them.
+type Payloader interface {
+ // FullPayload returns all available bytes.
+ FullPayload() ([]byte, *Error)
+
+ // Payload returns a slice containing at most size bytes.
+ Payload(size int) ([]byte, *Error)
}
-// SlicePayload implements Payload on top of slices for convenience.
+// SlicePayload implements Payloader for slices.
+//
+// This is typically used for tests.
type SlicePayload []byte
-// Get implements Payload.
-func (s SlicePayload) Get(size int) ([]byte, *Error) {
- if size > s.Size() {
- size = s.Size()
- }
- return s[:size], nil
+// FullPayload implements Payloader.FullPayload.
+func (s SlicePayload) FullPayload() ([]byte, *Error) {
+ return s, nil
}
-// Size implements Payload.
-func (s SlicePayload) Size() int {
- return len(s)
+// Payload implements Payloader.Payload.
+func (s SlicePayload) Payload(size int) ([]byte, *Error) {
+ if size > len(s) {
+ size = len(s)
+ }
+ return s[:size], nil
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -329,7 +341,7 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
+ Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
@@ -389,6 +401,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -423,16 +439,33 @@ type WriteOptions struct {
// EndOfRecord has the same semantics as Linux's MSG_EOR.
EndOfRecord bool
+
+ // Atomic means that all data fetched from Payloader must be written to the
+ // endpoint. If Atomic is false, then data fetched from the Payloader may be
+ // discarded if available endpoint buffer space is unsufficient.
+ Atomic bool
}
// SockOpt represents socket options which values have the int type.
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -441,18 +474,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -474,6 +495,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
@@ -581,7 +606,7 @@ type Route struct {
}
// String implements the fmt.Stringer interface.
-func (r *Route) String() string {
+func (r Route) String() string {
var out strings.Builder
fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
@@ -591,9 +616,6 @@ func (r *Route) String() string {
return out.String()
}
-// LinkEndpointID represents a data link layer endpoint.
-type LinkEndpointID uint64
-
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -720,6 +742,10 @@ type ICMPv4SentPacketStats struct {
// Dropped is the total number of ICMPv4 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
@@ -738,6 +764,10 @@ type ICMPv6SentPacketStats struct {
// Dropped is the total number of ICMPv6 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 451d3880e..a3a910d41 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,7 +15,6 @@
package icmp
import (
- "encoding/binary"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -105,7 +104,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -205,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -290,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -320,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -332,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -342,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -368,14 +372,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident to the user-specified port. Sequence number should
- // already be set by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident)
-
hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(icmpv4, data)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ icmpv4.SetIdent(ident)
data = data[header.ICMPv4MinimumSize:]
// Linux performs these basic checks.
@@ -394,14 +397,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
-
- hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
copy(icmpv6, data)
- data = data[header.ICMPv6EchoMinimumSize:]
+ // Set the ident. Sequence number is provided by the user.
+ icmpv6.SetIdent(ident)
+ data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return tcpip.ErrInvalidEndpointState
@@ -541,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7fdba5d56..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,16 +14,14 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
import (
- "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -92,7 +84,7 @@ func (p *protocol) MinimumPacketSize() int {
case ProtocolNumber4:
return header.ICMPv4MinimumSize
case ProtocolNumber6:
- return header.ICMPv6EchoMinimumSize
+ return header.ICMPv6MinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
@@ -101,16 +93,18 @@ func (p *protocol) MinimumPacketSize() int {
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil
+ hdr := header.ICMPv4(v)
+ return 0, hdr.Ident(), nil
case ProtocolNumber6:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ hdr := header.ICMPv6(v)
+ return 0, hdr.Ident(), nil
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -124,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 13e17e2a6..a02731a5d 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
return 0, nil, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := payload.Get(payload.Size())
+ payloadBytes, err := p.FullPayload()
if err != nil {
- ep.mu.RUnlock()
return 0, nil, err
}
@@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
// destination address, route using that address.
if !ep.associated {
ip := header.IPv4(payloadBytes)
- if !ip.IsValid(payload.Size()) {
+ if !ip.IsValid(len(payloadBytes)) {
ep.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -493,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -505,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
ep.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSize
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
@@ -516,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 1ee1a53f8..a42e1f4a2 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "tcp_segment_list",
@@ -47,6 +49,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e9c5099ea..3ae4a5426 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -143,6 +143,15 @@ func decSynRcvdCount() {
synRcvdCount.Unlock()
}
+// synCookiesInUse() returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func synCookiesInUse() bool {
+ synRcvdCount.Lock()
+ v := synRcvdCount.value
+ synRcvdCount.Unlock()
+ return v >= SynRcvdCountThreshold
+}
+
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -233,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -446,6 +455,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ if !synCookiesInUse() {
+ // Send a reset as this is an ACK for which there is no
+ // half open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ replyWithReset(s)
+ return
+ }
+
+ // Since SYN cookies are in use this is potentially an ACK to a
+ // SYN-ACK we sent but don't have a half open connection state
+ // as cookies are being used to protect against a potential SYN
+ // flood. In such cases validate the cookie and if valid create
+ // a fully connected endpoint and deliver to the accept queue.
+ //
+ // If not, silently drop the ACK to avoid leaking information
+ // when under a potential syn flood attack.
+ //
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
if !ok || int(data) >= len(mssTable) {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 00d2ae524..21038a65a 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -720,13 +720,18 @@ func (e *endpoint) handleClose() *tcpip.Error {
return nil
}
-// resetConnectionLocked sends a RST segment and puts the endpoint in an error
-// state with the given error code. This method must only be called from the
-// protocol goroutine.
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST if and only if the error is not ErrConnectionReset
+// indicating that the connection is being reset due to receiving a RST. This
+// method must only be called from the protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset.
e.state = StateError
e.hardError = err
+ if err != tcpip.ErrConnectionReset {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ }
}
// completeWorkerLocked is called by the worker goroutine when it's about to
@@ -806,7 +811,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
- return tcpip.ErrConnectionReset
+ return tcpip.ErrTimeout
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -1068,6 +1073,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.workMu.Lock()
if err := funcs[v].f(); err != nil {
e.mu.Lock()
+ // Ensure we release all endpoint registration and route
+ // references as the connection is now in an error
+ // state.
+ e.workerCleanup = true
e.resetConnectionLocked(err)
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ac927569a..f9d5e0085 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -280,6 +282,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -564,11 +569,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +630,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -806,7 +811,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -821,47 +826,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, err
}
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
-
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
}
- // Copy in memory without holding sndBufMu so that worker goroutine can
- // make progress independent of this operation.
- v, perr := p.Get(avail)
- if perr != nil {
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ if opts.Atomic { // See above.
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ }
+ // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a
- // write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- return 0, nil, err
- }
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
- // Discard any excess data copied in due to avail being reduced due to a
- // simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
// Add data to the send queue.
- l := len(v)
s := newSegmentFromView(&e.route, e.id, v)
- e.sndBufUsed += l
- e.sndBufInQueue += seqnum.Size(l)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
-
e.sndBufMu.Unlock()
// Release the endpoint lock to prevent deadlocks due to lock
// order inversion when acquiring workMu.
@@ -875,7 +885,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return int64(l), nil, nil
+
+ return int64(len(v)), nil, nil
}
// Peek reads data without consuming it from the endpoint.
@@ -946,62 +957,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1065,6 +1023,82 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -1176,6 +1210,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1198,18 +1244,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1246,6 +1280,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1452,7 +1496,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1462,17 +1506,32 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+
+ // Calculate a port offset based on the destination IP/port and
+ // src IP to ensure that for a given tuple (srcIP, destIP,
+ // destPort) the offset used as a starting point is the same to
+ // ensure that we can cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.id.LocalAddress))
+ h.Write([]byte(e.id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
e.id = id
return true, nil
@@ -1490,7 +1549,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1634,7 +1693,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1715,7 +1774,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1725,16 +1784,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.id.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.id.LocalPort = 0
e.id.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index ee04dcfcc..d5d8ab96a 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -129,7 +126,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 1f9b1e0ef..735edfe55 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -664,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..9fa97528b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index f79b8ec5f..089826a88 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
@@ -190,21 +190,122 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ c.GetPacket()
+
+ // Now resend the same ACK, this ACK should generate a RST as there
+ // should be no endpoint in SYN-RCVD state and we are not using
+ // syn-cookies yet. The reason we send the same ACK is we need a valid
+ // cookie(IRS) generated by the netstack without which the ACK will be
+ // rejected.
+ c.SendPacket(nil, ackHeaders)
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+}
+
func TestTCPResetsReceivedIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- ackNum := seqnum.Value(789)
+ iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(ackNum, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: c.IRS.Add(2),
- AckNum: ackNum.Add(2),
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
RcvWnd: rcvWnd,
Flags: header.TCPFlagRst,
})
@@ -214,18 +315,43 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
}
}
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -241,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -291,7 +417,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -339,11 +465,71 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -431,8 +617,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -505,7 +690,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -574,7 +759,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -659,7 +844,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -678,8 +863,7 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -746,11 +930,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -850,7 +1032,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -891,7 +1073,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -949,8 +1131,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -984,8 +1165,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1025,7 +1205,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1098,7 +1278,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1167,8 +1347,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1273,7 +1452,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1323,7 +1502,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1371,7 +1550,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1453,7 +1632,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1569,7 +1748,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -1578,7 +1757,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1601,7 +1780,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1745,7 +1924,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1847,7 +2026,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1884,7 +2063,7 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1909,7 +2088,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -1952,7 +2131,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2006,7 +2185,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2077,7 +2256,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2165,7 +2344,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2251,7 +2430,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2383,7 +2562,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2433,7 +2612,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2454,7 +2633,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2478,7 +2657,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2509,7 +2688,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2555,7 +2734,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2730,8 +2909,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2743,8 +2922,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2754,7 +2933,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2800,7 +2982,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2819,37 +3004,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3105,7 +3349,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3182,7 +3426,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3356,7 +3600,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
@@ -3459,8 +3703,8 @@ func TestKeepalive(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 272481aa0..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- id, linkEP := channel.New(1000, mtu, "")
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context {
return &Context{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -267,7 +278,7 @@ func (c *Context) GetPacketNonBlocking() []byte {
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2))
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
@@ -286,9 +297,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
icmp.SetType(typ)
icmp.SetCode(code)
-
- copy(icmp[header.ICMPv4PayloadOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2)
+ const icmpv4VariableHeaderOffset = 4
+ copy(icmp[icmpv4VariableHeaderOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
@@ -511,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -584,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -597,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 4bec48c0f..43fcc27f0 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..7a635ab8d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "udp_packet_list",
@@ -50,6 +52,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ac5905772..52f5af777 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,13 +36,17 @@ type udpPacket struct {
views [8]buffer.View `state:"nosave"`
}
-type endpointState int
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
)
// endpoint represents a UDP endpoint. This struct serves as the interface
@@ -74,7 +77,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
sndBufSize int
id stack.TransportEndpointID
- state endpointState
+ state EndpointState
bindNICID tcpip.NICID
regNICID tcpip.NICID
route stack.Route `state:"manual"`
@@ -85,6 +88,7 @@ type endpoint struct {
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
// shutdownFlags represent the current shutdown state of the endpoint.
@@ -140,9 +144,9 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
- case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ case StateBound, StateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
@@ -163,7 +167,7 @@ func (e *endpoint) Close() {
e.route.Release()
// Update the state.
- e.state = stateClosed
+ e.state = StateClosed
e.mu.Unlock()
@@ -211,11 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
- case stateInitial:
- case stateConnected:
+ case StateInitial:
+ case StateConnected:
return false, nil
- case stateBound:
+ case StateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
@@ -232,7 +236,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return true, nil
}
@@ -273,17 +277,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -322,7 +321,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
- if e.state != stateConnected {
+ if e.state != StateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
@@ -366,10 +365,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
@@ -387,7 +390,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -400,7 +408,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -544,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -566,7 +589,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -576,18 +612,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -638,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -726,7 +760,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
@@ -741,12 +775,16 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = stateBound
+ e.state = StateBound
} else {
- e.state = stateInitial
+ if e.id.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
+ }
+ e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.id = id
e.route.Release()
e.route = stack.Route{}
@@ -772,8 +810,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicid := addr.NIC
var localPort uint16
switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
+ case StateInitial:
+ case StateBound, StateConnected:
localPort = e.id.LocalPort
if e.bindNICID == 0 {
break
@@ -801,7 +839,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == stateInitial {
+ if e.state == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -823,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
}
e.id = id
@@ -832,7 +870,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.regNICID = nicid
e.effectiveNetProtos = netProtos
- e.state = stateConnected
+ e.state = StateConnected
e.rcvMu.Lock()
e.rcvReady = true
@@ -854,7 +892,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -886,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
@@ -903,7 +941,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -946,7 +984,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
e.rcvMu.Lock()
e.rcvReady = true
@@ -989,7 +1027,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1069,10 +1107,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
-// State implements socket.Socket.State.
+// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- // TODO(b/112063468): Translate internal state to values returned by Linux.
- return 0
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
}
func isBroadcastOrMulticast(a tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 5cbb56120..be46e6d4e 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -77,7 +77,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return
}
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == stateConnected {
+ if e.state == StateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a874fc9fd..2d0bc5221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -84,7 +84,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.dstPort = r.id.RemotePort
ep.regNICID = r.route.NICID()
- ep.state = stateConnected
+ ep.state = StateConnected
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f76e7fbe1..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send ICMP error if the address is not a multicast/broadcast
+ // v4/v6 address or the source is not the unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 where it mentioned that at least 8
+ // bytes of the payload must be included. Today linux and other
+ // systems implement the] RFC1812 definition and not the original
+ // RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL())
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ }
return true
}
@@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 9da6edce2..5059ca22d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,13 +274,17 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -461,94 +465,70 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
}
func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
@@ -1238,3 +1218,153 @@ func TestMulticastInterfaceOption(t *testing.T) {
})
}
}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true, will result in a payload large enough
+ // so that the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true will result in a payload large enough to
+ // create an IPv6 packet > header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 98d51cc69..6afdb29b7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index cbd92fc05..8f6f180e5 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b7f505a84..b6bbb0ea2 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 9173dfd0f..8dc88becb 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",