summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml10
-rw-r--r--g3doc/user_guide/install.md10
-rw-r--r--pkg/p9/p9.go11
-rw-r--r--pkg/ring0/BUILD1
-rw-r--r--pkg/ring0/defs.go3
-rw-r--r--pkg/ring0/gen_offsets/BUILD1
-rw-r--r--pkg/ring0/kernel_amd64.go12
-rw-r--r--pkg/ring0/kernel_arm64.go4
-rw-r--r--pkg/sentry/arch/BUILD3
-rw-r--r--pkg/sentry/arch/arch.go34
-rw-r--r--pkg/sentry/arch/arch_aarch64.go71
-rw-r--r--pkg/sentry/arch/arch_amd64.go13
-rw-r--r--pkg/sentry/arch/arch_arm64.go13
-rw-r--r--pkg/sentry/arch/arch_state_x86.go54
-rw-r--r--pkg/sentry/arch/arch_x86.go217
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go3
-rw-r--r--pkg/sentry/arch/fpu/BUILD21
-rw-r--r--pkg/sentry/arch/fpu/fpu.go54
-rw-r--r--pkg/sentry/arch/fpu/fpu_amd64.go280
-rw-r--r--pkg/sentry/arch/fpu/fpu_amd64.s (renamed from pkg/sentry/arch/arch_amd64.s)0
-rw-r--r--pkg/sentry/arch/fpu/fpu_arm64.go63
-rw-r--r--pkg/sentry/arch/signal_amd64.go9
-rw-r--r--pkg/sentry/arch/signal_arm64.go7
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go21
-rw-r--r--pkg/sentry/fs/gofer/file.go27
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go14
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go21
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go21
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go43
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go10
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go10
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go2
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go33
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go13
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go3
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go3
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go9
-rw-r--r--pkg/sentry/syscalls/linux/error.go26
-rw-r--r--pkg/syserror/syserror.go10
-rw-r--r--pkg/tcpip/checker/checker.go7
-rw-r--r--pkg/tcpip/header/ndp_options.go156
-rw-r--r--pkg/tcpip/header/ndp_test.go695
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go32
-rw-r--r--pkg/tcpip/network/arp/arp.go21
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go139
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go124
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go121
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go60
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go155
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go16
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go18
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go14
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go72
-rw-r--r--pkg/tcpip/stack/ndp_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go14
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go31
-rw-r--r--pkg/tcpip/stack/stack_test.go138
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go4
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go261
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go148
-rw-r--r--pkg/tcpip/tests/integration/route_test.go132
-rw-r--r--pkg/tcpip/tests/utils/BUILD3
-rw-r--r--pkg/tcpip/tests/utils/utils.go60
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go507
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go4
-rw-r--r--pkg/test/dockerutil/network.go5
-rw-r--r--runsc/boot/controller.go2
-rw-r--r--runsc/boot/fs.go15
-rw-r--r--runsc/boot/fs_test.go3
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/loader_test.go6
-rw-r--r--runsc/boot/vfs.go2
-rw-r--r--runsc/cmd/do.go15
-rw-r--r--runsc/cmd/gofer.go18
-rw-r--r--runsc/config/config.go21
-rw-r--r--runsc/config/flags.go3
-rw-r--r--runsc/fsgofer/filter/config.go5
-rw-r--r--runsc/fsgofer/filter/filter.go6
-rw-r--r--runsc/fsgofer/fsgofer.go23
-rw-r--r--runsc/fsgofer/fsgofer_test.go26
-rw-r--r--runsc/specutils/specutils.go11
-rw-r--r--test/benchmarks/fs/BUILD2
-rw-r--r--test/benchmarks/fs/bazel_test.go23
-rw-r--r--test/benchmarks/fs/fio_test.go20
-rw-r--r--test/benchmarks/harness/BUILD1
-rw-r--r--test/benchmarks/harness/util.go34
-rw-r--r--test/packetimpact/runner/defs.bzl12
-rw-r--r--test/packetimpact/runner/dut.go22
-rw-r--r--test/packetimpact/tests/BUILD22
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_closing_test.go86
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go72
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go94
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_test.go63
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/read.cc40
-rw-r--r--test/syscalls/linux/setgid.cc95
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc111
-rw-r--r--test/syscalls/linux/write.cc78
-rw-r--r--tools/go_marshal/gomarshal/BUILD1
-rw-r--r--tools/go_marshal/gomarshal/generator.go31
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_dynamic.go96
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go339
-rw-r--r--tools/go_marshal/test/BUILD5
-rw-r--r--tools/go_marshal/test/dynamic.go83
-rw-r--r--tools/go_marshal/test/marshal_test.go33
-rw-r--r--tools/go_marshal/test/test.go35
-rw-r--r--tools/go_stateify/main.go2
119 files changed, 3649 insertions, 2187 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index aa2fd1f47..3bc5041c0 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -186,10 +186,14 @@ steps:
# For fio, running with --test.benchtime=Xs scales the written/read
# bytes to several GB. This is not a problem for root/bind/volume mounts,
# but for tmpfs mounts, the size can grow to more memory than the machine
- # has availabe. Fix the runs to 10GB written/read for the benchmark.
+ # has available. Fix the runs to 1GB written/read for the benchmark.
- <<: *benchmarks
- label: ":floppy_disk: FIO benchmarks"
- command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_OPTIONS=--test.benchtime=10000x
+ label: ":floppy_disk: FIO benchmarks (read/write)"
+ command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_FILTER=Fio/operation\.[rw][er] BENCHMARKS_OPTIONS=--test.benchtime=1000x
+ # For rand(read|write) fio benchmarks, running 15s does not overwhelm the system for tmpfs mounts.
+ - <<: *benchmarks
+ label: ":cd: FIO benchmarks (randread/randwrite)"
+ command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_FILTER=Fio/operation\.rand BENCHMARKS_OPTIONS=--test.benchtime=15s
- <<: *benchmarks
label: ":globe_with_meridians: HTTPD benchmarks"
command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=httpd BENCHMARKS_TARGETS=test/benchmarks/network:httpd_test
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
index ad0ab9923..bcfba0179 100644
--- a/g3doc/user_guide/install.md
+++ b/g3doc/user_guide/install.md
@@ -59,7 +59,7 @@ Next, the configure the key used to sign archives and the repository:
```bash
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main"
```
Now the runsc package can be installed:
@@ -96,7 +96,7 @@ You can use this link with the steps described in
For `apt` installation, use the `master` to configure the repository:
```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases master main"
+sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases master main"
```
### Nightly
@@ -118,7 +118,7 @@ Note that a release may not be available for every day.
For `apt` installation, use the `nightly` to configure the repository:
```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases nightly main"
+sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases nightly main"
```
### Latest release
@@ -133,7 +133,7 @@ You can use this link with the steps described in
For `apt` installation, use the `release` to configure the repository:
```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main"
```
### Specific release
@@ -152,7 +152,7 @@ For `apt` installation of a specific release, which may include point updates,
use the date of the release for repository, e.g. `${yyyymmdd}`.
```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases yyyymmdd main"
+sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases yyyymmdd main"
```
> Note: only newer releases may be available as `apt` repositories.
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 2235f8968..648cf4b49 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -151,9 +151,16 @@ const (
// Sticky is a mode bit indicating sticky directories.
Sticky FileMode = 01000
+ // SetGID is the set group ID bit.
+ SetGID FileMode = 02000
+
+ // SetUID is the set user ID bit.
+ SetUID FileMode = 04000
+
// permissionsMask is the mask to apply to FileModes for permissions. It
- // includes rwx bits for user, group and others, and sticky bit.
- permissionsMask FileMode = 01777
+ // includes rwx bits for user, group, and others, as well as the sticky
+ // bit, setuid bit, and setgid bit.
+ permissionsMask FileMode = 07777
)
// QIDType is the most significant byte of the FileMode word, to be used as the
diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD
index d1b14efdb..885958456 100644
--- a/pkg/ring0/BUILD
+++ b/pkg/ring0/BUILD
@@ -80,6 +80,7 @@ go_library(
"//pkg/ring0/pagetables",
"//pkg/safecopy",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/usermem",
],
)
diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go
index e8ce608ba..b6e2012e8 100644
--- a/pkg/ring0/defs.go
+++ b/pkg/ring0/defs.go
@@ -17,6 +17,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
)
// Kernel is a global kernel object.
@@ -96,7 +97,7 @@ type SwitchOpts struct {
// FloatingPointState is a byte pointer where floating point state is
// saved and restored.
- FloatingPointState arch.FloatingPointData
+ FloatingPointState *fpu.State
// PageTables are the application page tables.
PageTables *pagetables.PageTables
diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD
index 15b93d61c..f421e1687 100644
--- a/pkg/ring0/gen_offsets/BUILD
+++ b/pkg/ring0/gen_offsets/BUILD
@@ -35,6 +35,7 @@ go_binary(
"//pkg/cpuid",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/usermem",
],
)
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index e9e706716..33c259757 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -239,17 +239,17 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Ss = uint64(Udata) // Ditto.
// Perform the switch.
- swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
- LoadFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy in floating point.
+ swapgs() // GS will be swapped on return.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point.
if switchOpts.FullRestore {
vector = iret(c, regs, uintptr(userCR3))
} else {
vector = sysret(c, regs, uintptr(userCR3))
}
- SaveFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index c9a120952..7975e5f92 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -62,7 +62,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
- storeEl0Fpstate(&switchOpts.FloatingPointState[0])
+ storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
if switchOpts.Flush {
FlushTlbByASID(uintptr(switchOpts.UserASID))
@@ -82,7 +82,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
fpDisableTrap = CPACREL1()
if fpDisableTrap != 0 {
- SaveFloatingPoint(&switchOpts.FloatingPointState[0])
+ SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer())
}
vector = c.vecCode
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 85278b389..f660f1614 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -9,7 +9,6 @@ go_library(
"arch.go",
"arch_aarch64.go",
"arch_amd64.go",
- "arch_amd64.s",
"arch_arm64.go",
"arch_state_x86.go",
"arch_x86.go",
@@ -36,8 +35,8 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/limits",
- "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 3443b9e1b..921151137 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -50,12 +51,6 @@ func (a Arch) String() string {
}
}
-// FloatingPointData is a generic type, and will always be passed as a pointer.
-// We rely on the individual arch implementations to meet all the necessary
-// requirements. For example, on x86 the region must be 16-byte aligned and 512
-// bytes in size.
-type FloatingPointData []byte
-
// Context provides architecture-dependent information for a specific thread.
//
// NOTE(b/34169503): Currently we use uintptr here to refer to a generic native
@@ -187,7 +182,7 @@ type Context interface {
ClearSingleStep()
// FloatingPointData will be passed to underlying save routines.
- FloatingPointData() FloatingPointData
+ FloatingPointData() *fpu.State
// NewMmapLayout returns a layout for a new MM, where MinAddr for the
// returned layout must be no lower than min, and MaxAddr for the returned
@@ -221,16 +216,6 @@ type Context interface {
// number of bytes read.
PtraceSetRegs(src io.Reader) (int, error)
- // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the
- // floating-point registers represented by this Context to addr in dst and
- // returning the number of bytes written.
- PtraceGetFPRegs(dst io.Writer) (int, error)
-
- // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading
- // floating-point registers from src into this Context and returning the
- // number of bytes read.
- PtraceSetFPRegs(src io.Reader) (int, error)
-
// PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the
// register set given by architecture-defined value regset from this
// Context to dst and returning the number of bytes written, which must be
@@ -365,18 +350,3 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
-
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 6b81e9708..08789f517 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,65 +41,11 @@ type Registers struct {
const (
// SyscallWidth is the width of insturctions.
SyscallWidth = 4
-
- // fpsimdMagic is the magic number which is used in fpsimd_context.
- fpsimdMagic = 0x46508001
-
- // fpsimdContextSize is the size of fpsimd_context.
- fpsimdContextSize = 0x210
)
// ARMTrapFlag is the mask for the trap flag.
const ARMTrapFlag = uint64(1) << 21
-// aarch64FPState is aarch64 floating point state.
-type aarch64FPState []byte
-
-// initAarch64FPState sets up initial state.
-//
-// Related code in Linux kernel: fpsimd_flush_thread().
-// FPCR = FPCR_RM_RN (0x0 << 22).
-//
-// Currently, aarch64FPState is only a space of 0x210 length for fpstate.
-// The fp head is useless in sentry/ptrace/kvm.
-//
-func initAarch64FPState(data aarch64FPState) {
-}
-
-func newAarch64FPStateSlice() []byte {
- return alignedBytes(4096, 16)[:fpsimdContextSize]
-}
-
-// newAarch64FPState returns an initialized floating point state.
-//
-// The returned state is large enough to store all floating point state
-// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet.
-func newAarch64FPState() aarch64FPState {
- f := aarch64FPState(newAarch64FPStateSlice())
- initAarch64FPState(f)
- return f
-}
-
-// fork creates and returns an identical copy of the aarch64 floating point state.
-func (f aarch64FPState) fork() aarch64FPState {
- n := aarch64FPState(newAarch64FPStateSlice())
- copy(n, f)
- return n
-}
-
-// FloatingPointData returns the raw data pointer.
-func (f aarch64FPState) FloatingPointData() FloatingPointData {
- return ([]byte)(f)
-}
-
-// NewFloatingPointData returns a new floating point data blob.
-//
-// This is primarily for use in tests.
-func NewFloatingPointData() FloatingPointData {
- return ([]byte)(newAarch64FPState())
-}
-
// State contains the common architecture bits for aarch64 (the build tag of this
// file ensures it's only built on aarch64).
//
@@ -108,7 +55,7 @@ type State struct {
Regs Registers
// Our floating point state.
- aarch64FPState `state:"wait"`
+ fpState fpu.State `state:"wait"`
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -162,10 +109,10 @@ func (s State) Proto() *rpb.Registers {
// Fork creates and returns an identical copy of the state.
func (s *State) Fork() State {
return State{
- Regs: s.Regs,
- aarch64FPState: s.aarch64FPState.fork(),
- FeatureSet: s.FeatureSet,
- OrigR0: s.OrigR0,
+ Regs: s.Regs,
+ fpState: s.fpState.Fork(),
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
}
}
@@ -318,10 +265,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case ARM64:
return &context64{
State{
- aarch64FPState: newAarch64FPState(),
- FeatureSet: fs,
+ fpState: fpu.NewState(),
+ FeatureSet: fs,
},
- []aarch64FPState(nil),
+ []fpu.State(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 15d8ddb40..2571be60f 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -105,7 +106,7 @@ const (
// +stateify savable
type context64 struct {
State
- sigFPState []x86FPState // fpstate to be restored on sigreturn.
+ sigFPState []fpu.State // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -113,14 +114,18 @@ func (c *context64) Arch() Arch {
return AMD64
}
-func (c *context64) copySigFPState() []x86FPState {
- var sigfps []x86FPState
+func (c *context64) copySigFPState() []fpu.State {
+ var sigfps []fpu.State
for _, s := range c.sigFPState {
- sigfps = append(sigfps, s.fork())
+ sigfps = append(sigfps, s.Fork())
}
return sigfps
}
+func (c *context64) FloatingPointData() *fpu.State {
+ return &c.State.fpState
+}
+
// Fork returns an exact copy of this context.
func (c *context64) Fork() Context {
return &context64{
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 0c61a3ff7..14ad9483b 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -79,7 +80,7 @@ const (
// +stateify savable
type context64 struct {
State
- sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
+ sigFPState []fpu.State // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -87,10 +88,10 @@ func (c *context64) Arch() Arch {
return ARM64
}
-func (c *context64) copySigFPState() []aarch64FPState {
- var sigfps []aarch64FPState
+func (c *context64) copySigFPState() []fpu.State {
+ var sigfps []fpu.State
for _, s := range c.sigFPState {
- sigfps = append(sigfps, s.fork())
+ sigfps = append(sigfps, s.Fork())
}
return sigfps
}
@@ -286,3 +287,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
// TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
return nil
}
+
+func (c *context64) FloatingPointData() *fpu.State {
+ return &c.State.fpState
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 840e53d33..b2b94c304 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -16,59 +16,7 @@
package arch
-import (
- "gvisor.dev/gvisor/pkg/cpuid"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
-// and SSE state, so this is the equivalent XSTATE_BV value.
-const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
-
// afterLoadFPState is invoked by afterLoad.
func (s *State) afterLoadFPState() {
- old := s.x86FPState
-
- // Recreate the slice. This is done to ensure that it is aligned
- // appropriately in memory, and large enough to accommodate any new
- // state that may be saved by the new CPU. Even if extraneous new state
- // is saved, the state we care about is guaranteed to be a subset of
- // new state. Later optimizations can use less space when using a
- // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
- // more info.
- s.x86FPState = newX86FPState()
-
- // x86FPState always contains all the FP state supported by the host.
- // We may have come from a newer machine that supports additional state
- // which we cannot restore.
- //
- // The x86 FP state areas are backwards compatible, so we can simply
- // truncate the additional floating point state.
- //
- // Applications should not depend on the truncated state because it
- // should relate only to features that were not exposed in the app
- // FeatureSet. However, because we do not *prevent* them from using
- // this state, we must verify here that there is no in-use state
- // (according to XSTATE_BV) which we do not support.
- if len(s.x86FPState) < len(old) {
- // What do we support?
- supportedBV := fxsaveBV
- if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
- supportedBV = fs.ValidXCR0Mask()
- }
-
- // What was in use?
- savedBV := fxsaveBV
- if len(old) >= xstateBVOffset+8 {
- savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
- }
-
- // Supported features must be a superset of saved features.
- if savedBV&^supportedBV != 0 {
- panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV})
- }
- }
-
- // Copy to the new, aligned location.
- copy(s.x86FPState, old)
+ s.fpState.AfterLoad()
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 91edf0703..e8e52d3a8 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -24,10 +24,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Registers represents the CPU registers for this architecture.
@@ -111,57 +110,6 @@ var (
X86TrapFlag uint64 = (1 << 8)
)
-// x86FPState is x86 floating point state.
-type x86FPState []byte
-
-// initX86FPState (defined in asm files) sets up initial state.
-func initX86FPState(data *byte, useXsave bool)
-
-func newX86FPStateSlice() []byte {
- size, align := cpuid.HostFeatureSet().ExtendedStateSize()
- capacity := size
- // Always use at least 4096 bytes.
- //
- // For the KVM platform, this state is a fixed 4096 bytes, so make sure
- // that the underlying array is at _least_ that size otherwise we will
- // corrupt random memory. This is not a pleasant thing to debug.
- if capacity < 4096 {
- capacity = 4096
- }
- return alignedBytes(capacity, align)[:size]
-}
-
-// newX86FPState returns an initialized floating point state.
-//
-// The returned state is large enough to store all floating point state
-// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet. Since they may still be able to see state not advertised by
-// CPUID we must ensure it does not contain any sentry state.
-func newX86FPState() x86FPState {
- f := x86FPState(newX86FPStateSlice())
- initX86FPState(&f.FloatingPointData()[0], cpuid.HostFeatureSet().UseXsave())
- return f
-}
-
-// fork creates and returns an identical copy of the x86 floating point state.
-func (f x86FPState) fork() x86FPState {
- n := x86FPState(newX86FPStateSlice())
- copy(n, f)
- return n
-}
-
-// FloatingPointData returns the raw data pointer.
-func (f x86FPState) FloatingPointData() FloatingPointData {
- return []byte(f)
-}
-
-// NewFloatingPointData returns a new floating point data blob.
-//
-// This is primarily for use in tests.
-func NewFloatingPointData() FloatingPointData {
- return (FloatingPointData)(newX86FPState())
-}
-
// Proto returns a protobuf representation of the system registers in State.
func (s State) Proto() *rpb.Registers {
regs := &rpb.AMD64Registers{
@@ -200,7 +148,7 @@ func (s State) Proto() *rpb.Registers {
func (s *State) Fork() State {
return State{
Regs: s.Regs,
- x86FPState: s.x86FPState.fork(),
+ fpState: s.fpState.Fork(),
FeatureSet: s.FeatureSet,
}
}
@@ -393,149 +341,6 @@ func isValidSegmentBase(reg uint64) bool {
return reg < uint64(maxAddr64)
}
-// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
-// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
-// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
-const ptraceFPRegsSize = 512
-
-// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
-func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
- return dst.Write(s.x86FPState[:ptraceFPRegsSize])
-}
-
-// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
-func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
- var f [ptraceFPRegsSize]byte
- n, err := io.ReadFull(src, f[:])
- if err != nil {
- return 0, err
- }
- // Force reserved bits in MXCSR to 0. This is consistent with Linux.
- sanitizeMXCSR(x86FPState(f[:]))
- // N.B. this only copies the beginning of the FP state, which
- // corresponds to the FXSAVE area.
- copy(s.x86FPState, f[:])
- return n, nil
-}
-
-const (
- // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
- // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
- // Area")
- mxcsrOffset = 24
-
- // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
- // start of the FXSAVE area.
- mxcsrMaskOffset = 28
-)
-
-var (
- mxcsrMask uint32
- initMXCSRMask sync.Once
-)
-
-// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
-// generates a general-protection fault (#GP) in response to an attempt to set
-// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
-// 10.5.1.2 "SSE State")
-func sanitizeMXCSR(f x86FPState) {
- mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
- initMXCSRMask.Do(func() {
- temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16))
- initX86FPState(&temp.FloatingPointData()[0], false /* useXsave */)
- mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
- if mxcsrMask == 0 {
- // "If the value of the MXCSR_MASK field is 00000000H, then the
- // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
- // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
- // Register"
- mxcsrMask = 0xffbf
- }
- })
- mxcsr &= mxcsrMask
- usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
-}
-
-const (
- // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
- // to the size of the XSAVE legacy area (512 bytes) plus the size of the
- // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
- // X86_XSTATE_SSE_SIZE.
- minXstateBytes = 512 + 64
-
- // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
- // field in Linux's struct user_xstateregs, which is the type manipulated
- // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
- // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
- userXstateXCR0Offset = 464
-
- // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
- // XSAVE area.
- xstateBVOffset = 512
-
- // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
- // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
- // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
- // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
- // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
- // exceptions resulting from invalid values; we aren't. Linux also never
- // uses the compacted format when doing XSAVE and doesn't even define the
- // compaction extensions to XSAVE as a CPU feature, so for simplicity we
- // assume no one is using them.
- xsaveHeaderZeroedOffset = 512 + 8
- xsaveHeaderZeroedBytes = 64 - 8
-)
-
-func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) {
- // N.B. s.x86FPState may contain more state than the application
- // expects. We only copy the subset that would be in their XSAVE area.
- ess, _ := s.FeatureSet.ExtendedStateSize()
- f := make([]byte, ess)
- copy(f, s.x86FPState)
- // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
- // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
- // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
- // mask. GDB relies on this: see
- // gdb/x86-linux-nat.c:x86_linux_read_description().
- usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask())
- if len(f) > maxlen {
- f = f[:maxlen]
- }
- return dst.Write(f)
-}
-
-func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) {
- // Allow users to pass an xstate register set smaller than ours (they can
- // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
- // Also allow users to pass a register set larger than ours; anything after
- // their ExtendedStateSize will be ignored. (I think Linux technically
- // permits setting a register set smaller than minXstateBytes, but it has
- // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
- if maxlen < minXstateBytes {
- return 0, unix.EFAULT
- }
- ess, _ := s.FeatureSet.ExtendedStateSize()
- if maxlen > int(ess) {
- maxlen = int(ess)
- }
- f := make([]byte, maxlen)
- if _, err := io.ReadFull(src, f); err != nil {
- return 0, err
- }
- // Force reserved bits in MXCSR to 0. This is consistent with Linux.
- sanitizeMXCSR(x86FPState(f))
- // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
- xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
- xstateBV &= s.FeatureSet.ValidXCR0Mask()
- usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
- // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
- reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
- for i := range reserved {
- reserved[i] = 0
- }
- return copy(s.x86FPState, f), nil
-}
-
// Register sets defined in include/uapi/linux/elf.h.
const (
_NT_PRSTATUS = 1
@@ -552,12 +357,9 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
}
return s.PtraceGetRegs(dst)
case _NT_PRFPREG:
- if maxlen < ptraceFPRegsSize {
- return 0, syserror.EFAULT
- }
- return s.PtraceGetFPRegs(dst)
+ return s.fpState.PtraceGetFPRegs(dst, maxlen)
case _NT_X86_XSTATE:
- return s.ptraceGetXstateRegs(dst, maxlen)
+ return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet)
default:
return 0, syserror.EINVAL
}
@@ -572,12 +374,9 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int,
}
return s.PtraceSetRegs(src)
case _NT_PRFPREG:
- if maxlen < ptraceFPRegsSize {
- return 0, syserror.EFAULT
- }
- return s.PtraceSetFPRegs(src)
+ return s.fpState.PtraceSetFPRegs(src, maxlen)
case _NT_X86_XSTATE:
- return s.ptraceSetXstateRegs(src, maxlen)
+ return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet)
default:
return 0, syserror.EINVAL
}
@@ -609,10 +408,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case AMD64:
return &context64{
State{
- x86FPState: newX86FPState(),
+ fpState: fpu.NewState(),
FeatureSet: fs,
},
- []x86FPState(nil),
+ []fpu.State(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
index 0c73fcbfb..5d7b99bd9 100644
--- a/pkg/sentry/arch/arch_x86_impl.go
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -18,6 +18,7 @@ package arch
import (
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
)
// State contains the common architecture bits for X86 (the build tag of this
@@ -29,7 +30,7 @@ type State struct {
Regs Registers
// Our floating point state.
- x86FPState `state:"wait"`
+ fpState fpu.State `state:"wait"`
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD
new file mode 100644
index 000000000..0a5395267
--- /dev/null
+++ b/pkg/sentry/arch/fpu/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fpu",
+ srcs = [
+ "fpu.go",
+ "fpu_amd64.go",
+ "fpu_amd64.s",
+ "fpu_arm64.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/cpuid",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go
new file mode 100644
index 000000000..867d309a3
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fpu provides basic floating point helpers.
+package fpu
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// State represents floating point state.
+//
+// This is a simple byte slice, but may have architecture-specific methods
+// attached to it.
+type State []byte
+
+// ErrLoadingState indicates a failed restore due to unusable floating point
+// state.
+type ErrLoadingState struct {
+ // supported is the supported floating point state.
+ supportedFeatures uint64
+
+ // saved is the saved floating point state.
+ savedFeatures uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrLoadingState) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures)
+}
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go
new file mode 100644
index 000000000..3a62f51be
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_amd64.go
@@ -0,0 +1,280 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 i386
+
+package fpu
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// initX86FPState (defined in asm files) sets up initial state.
+func initX86FPState(data *byte, useXsave bool)
+
+func newX86FPStateSlice() State {
+ size, align := cpuid.HostFeatureSet().ExtendedStateSize()
+ capacity := size
+ // Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
+ if capacity < 4096 {
+ capacity = 4096
+ }
+ return alignedBytes(capacity, align)[:size]
+}
+
+// NewState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet. Since they may still be able to see state not advertised by
+// CPUID we must ensure it does not contain any sentry state.
+func NewState() State {
+ f := newX86FPStateSlice()
+ initX86FPState(&f[0], cpuid.HostFeatureSet().UseXsave())
+ return f
+}
+
+// Fork creates and returns an identical copy of the x86 floating point state.
+func (s *State) Fork() State {
+ n := newX86FPStateSlice()
+ copy(n, *s)
+ return n
+}
+
+// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
+// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
+// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
+const ptraceFPRegsSize = 512
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer, maxlen int) (int, error) {
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+
+ return dst.Write((*s)[:ptraceFPRegsSize])
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader, maxlen int) (int, error) {
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+
+ var f [ptraceFPRegsSize]byte
+ n, err := io.ReadFull(src, f[:])
+ if err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(State(f[:]))
+ // N.B. this only copies the beginning of the FP state, which
+ // corresponds to the FXSAVE area.
+ copy(*s, f[:])
+ return n, nil
+}
+
+const (
+ // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
+ // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
+ // Area")
+ mxcsrOffset = 24
+
+ // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
+ // start of the FXSAVE area.
+ mxcsrMaskOffset = 28
+)
+
+var (
+ mxcsrMask uint32
+ initMXCSRMask sync.Once
+)
+
+const (
+ // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
+ // to the size of the XSAVE legacy area (512 bytes) plus the size of the
+ // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
+ // X86_XSTATE_SSE_SIZE.
+ minXstateBytes = 512 + 64
+
+ // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
+ // field in Linux's struct user_xstateregs, which is the type manipulated
+ // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
+ // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
+ userXstateXCR0Offset = 464
+
+ // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
+ // XSAVE area.
+ xstateBVOffset = 512
+
+ // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
+ // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
+ // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
+ // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
+ // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
+ // exceptions resulting from invalid values; we aren't. Linux also never
+ // uses the compacted format when doing XSAVE and doesn't even define the
+ // compaction extensions to XSAVE as a CPU feature, so for simplicity we
+ // assume no one is using them.
+ xsaveHeaderZeroedOffset = 512 + 8
+ xsaveHeaderZeroedBytes = 64 - 8
+)
+
+// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
+// generates a general-protection fault (#GP) in response to an attempt to set
+// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
+// 10.5.1.2 "SSE State")
+func sanitizeMXCSR(f State) {
+ mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
+ initMXCSRMask.Do(func() {
+ temp := State(alignedBytes(uint(ptraceFPRegsSize), 16))
+ initX86FPState(&temp[0], false /* useXsave */)
+ mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
+ if mxcsrMask == 0 {
+ // "If the value of the MXCSR_MASK field is 00000000H, then the
+ // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
+ // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
+ // Register"
+ mxcsrMask = 0xffbf
+ }
+ })
+ mxcsr &= mxcsrMask
+ usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
+}
+
+// PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by
+// writing the floating point registers from this state to dst and returning the
+// number of bytes written, which must be less than or equal to maxlen.
+func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid.FeatureSet) (int, error) {
+ // N.B. s.x86FPState may contain more state than the application
+ // expects. We only copy the subset that would be in their XSAVE area.
+ ess, _ := featureSet.ExtendedStateSize()
+ f := make([]byte, ess)
+ copy(f, *s)
+ // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
+ // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
+ // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
+ // mask. GDB relies on this: see
+ // gdb/x86-linux-nat.c:x86_linux_read_description().
+ usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask())
+ if len(f) > maxlen {
+ f = f[:maxlen]
+ }
+ return dst.Write(f)
+}
+
+// PtraceSetXstateRegs implements ptrace(PTRACE_SETREGS, NT_X86_XSTATE) by
+// reading floating point registers from src and returning the number of bytes
+// read, which must be less than or equal to maxlen.
+func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid.FeatureSet) (int, error) {
+ // Allow users to pass an xstate register set smaller than ours (they can
+ // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
+ // Also allow users to pass a register set larger than ours; anything after
+ // their ExtendedStateSize will be ignored. (I think Linux technically
+ // permits setting a register set smaller than minXstateBytes, but it has
+ // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
+ if maxlen < minXstateBytes {
+ return 0, unix.EFAULT
+ }
+ ess, _ := featureSet.ExtendedStateSize()
+ if maxlen > int(ess) {
+ maxlen = int(ess)
+ }
+ f := make([]byte, maxlen)
+ if _, err := io.ReadFull(src, f); err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(State(f))
+ // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
+ xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
+ xstateBV &= featureSet.ValidXCR0Mask()
+ usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
+ // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
+ reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
+ for i := range reserved {
+ reserved[i] = 0
+ }
+ return copy(*s, f), nil
+}
+
+// BytePointer returns a pointer to the first byte of the state.
+//
+//go:nosplit
+func (s *State) BytePointer() *byte {
+ return &(*s)[0]
+}
+
+// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
+// and SSE state, so this is the equivalent XSTATE_BV value.
+const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
+
+// AfterLoad converts the loaded state to a format that is compatible with
+// the current processor.
+func (s *State) AfterLoad() {
+ old := *s
+
+ // Recreate the slice. This is done to ensure that it is aligned
+ // appropriately in memory, and large enough to accommodate any new
+ // state that may be saved by the new CPU. Even if extraneous new state
+ // is saved, the state we care about is guaranteed to be a subset of
+ // new state. Later optimizations can use less space when using a
+ // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
+ // more info.
+ *s = NewState()
+
+ // x86FPState always contains all the FP state supported by the host.
+ // We may have come from a newer machine that supports additional state
+ // which we cannot restore.
+ //
+ // The x86 FP state areas are backwards compatible, so we can simply
+ // truncate the additional floating point state.
+ //
+ // Applications should not depend on the truncated state because it
+ // should relate only to features that were not exposed in the app
+ // FeatureSet. However, because we do not *prevent* them from using
+ // this state, we must verify here that there is no in-use state
+ // (according to XSTATE_BV) which we do not support.
+ if len(*s) < len(old) {
+ // What do we support?
+ supportedBV := fxsaveBV
+ if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
+ supportedBV = fs.ValidXCR0Mask()
+ }
+
+ // What was in use?
+ savedBV := fxsaveBV
+ if len(old) >= xstateBVOffset+8 {
+ savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
+ }
+
+ // Supported features must be a superset of saved features.
+ if savedBV&^supportedBV != 0 {
+ panic(ErrLoadingState{supportedFeatures: supportedBV, savedFeatures: savedBV})
+ }
+ }
+
+ // Copy to the new, aligned location.
+ copy(*s, old)
+}
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/fpu/fpu_amd64.s
index 6c10336e7..6c10336e7 100644
--- a/pkg/sentry/arch/arch_amd64.s
+++ b/pkg/sentry/arch/fpu/fpu_amd64.s
diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go
new file mode 100644
index 000000000..d2f62631d
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_arm64.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package fpu
+
+const (
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
+)
+
+// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is only a buffer of 0x210 bytes for fpstate.
+// The fp head is useless in sentry/ptrace/kvm.
+//
+func initAarch64FPState(data *State) {
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
+}
+
+// NewState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet.
+func NewState() State {
+ f := State(newAarch64FPStateSlice())
+ initAarch64FPState(&f)
+ return f
+}
+
+// Fork creates and returns an identical copy of the aarch64 floating point state.
+func (s *State) Fork() State {
+ n := State(newAarch64FPStateSlice())
+ copy(n, *s)
+ return n
+}
+
+// BytePointer returns a pointer to the first byte of the state.
+func (s *State) BytePointer() *byte {
+ return &(*s)[0]
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index e6557cab6..ee3743483 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -98,7 +99,7 @@ func (c *context64) NewSignalStack() NativeSignalStack {
const _FP_XSTATE_MAGIC2_SIZE = 4
func (c *context64) fpuFrameSize() (size int, useXsave bool) {
- size = len(c.x86FPState)
+ size = len(c.fpState)
if size > 512 {
// Make room for the magic cookie at the end of the xsave frame.
size += _FP_XSTATE_MAGIC2_SIZE
@@ -226,10 +227,10 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Ss = userDS
// Save the thread's floating point state.
- c.sigFPState = append(c.sigFPState, c.x86FPState)
+ c.sigFPState = append(c.sigFPState, c.fpState)
// Signal handler gets a clean floating point state.
- c.x86FPState = newX86FPState()
+ c.fpState = fpu.NewState()
return nil
}
@@ -273,7 +274,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt
// Restore floating point state.
l := len(c.sigFPState)
if l > 0 {
- c.x86FPState = c.sigFPState[l-1]
+ c.fpState = c.sigFPState[l-1]
// NOTE(cl/133042258): State save requires that any slice
// elements from '[len:cap]' to be zero value.
c.sigFPState[l-1] = nil
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 4491008c2..53281dcba 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -139,9 +140,9 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Regs[30] = uint64(act.Restorer)
// Save the thread's floating point state.
- c.sigFPState = append(c.sigFPState, c.aarch64FPState)
+ c.sigFPState = append(c.sigFPState, c.fpState)
// Signal handler gets a clean floating point state.
- c.aarch64FPState = newAarch64FPState()
+ c.fpState = fpu.NewState()
return nil
}
@@ -166,7 +167,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt
// Restore floating point state.
l := len(c.sigFPState)
if l > 0 {
- c.aarch64FPState = c.sigFPState[l-1]
+ c.fpState = c.sigFPState[l-1]
// NOTE(cl/133042258): State save requires that any slice
// elements from '[len:cap]' to be zero value.
c.sigFPState[l-1] = nil
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 82eda3e43..0ed7aafa5 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -380,16 +380,17 @@ func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length in
return nil
}
-// WriteOut implements fs.InodeOperations.WriteOut.
-func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+// WriteDirtyPagesAndAttrs will write the dirty pages and attributes to the
+// gofer without calling Fsync on the remote file.
+func (c *CachingInodeOperations) WriteDirtyPagesAndAttrs(ctx context.Context, inode *fs.Inode) error {
c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
// Write dirty pages back.
- c.dataMu.Lock()
err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt)
- c.dataMu.Unlock()
if err != nil {
- c.attrMu.Unlock()
return err
}
@@ -399,12 +400,18 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
// Write out cached attributes.
if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
- c.attrMu.Unlock()
return err
}
c.dirtyAttr = fs.AttrMask{}
- c.attrMu.Unlock()
+ return nil
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if err := c.WriteDirtyPagesAndAttrs(ctx, inode); err != nil {
+ return err
+ }
// Fsync the remote file.
return c.backingFile.Sync(ctx)
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 06d450ba6..8f5a87120 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -204,20 +204,8 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr
return entries, nil
}
-// maybeSync will call FSync on the file if either the cache policy or file
-// flags require it.
+// maybeSync will call FSync on the file if the file flags require it.
func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error {
- if n == 0 {
- // Nothing to sync.
- return nil
- }
-
- if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
- // Call WriteOut directly, as some "writethrough" filesystems
- // do not support sync.
- return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
- }
-
flags := file.Flags()
var syncType fs.SyncType
switch {
@@ -254,6 +242,19 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
}
+ if n == 0 {
+ // Nothing written. We are done.
+ return 0, err
+ }
+
+ // Write the dirty pages and attributes if cache policy tells us to.
+ if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
+ if werr := f.inodeOperations.cachingInodeOps.WriteDirtyPagesAndAttrs(ctx, file.Dirent.Inode); werr != nil {
+ // Report no bytes written since the write failed.
+ return 0, werr
+ }
+ }
+
// We may need to sync the written bytes.
if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil {
// Sync failed. Report 0 bytes written, since none of them are
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index c34451269..43c3c5a2d 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -783,7 +783,15 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
- if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ // If the parent is a setgid directory, use the parent's GID
+ // rather than the caller's and enable setgid.
+ kgid := creds.EffectiveKGID
+ mode := opts.Mode
+ if atomic.LoadUint32(&parent.mode)&linux.S_ISGID != 0 {
+ kgid = auth.KGID(atomic.LoadUint32(&parent.gid))
+ mode |= linux.S_ISGID
+ }
+ if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
@@ -1145,7 +1153,15 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
name := rp.Component()
// We only want the access mode for creating the file.
createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
- fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+
+ // If the parent is a setgid directory, use the parent's GID rather
+ // than the caller's.
+ kgid := creds.EffectiveKGID
+ if atomic.LoadUint32(&d.mode)&linux.S_ISGID != 0 {
+ kgid = auth.KGID(atomic.LoadUint32(&d.gid))
+ }
+
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
if err != nil {
dirfile.close(ctx)
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 71569dc65..692da02c1 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1102,10 +1102,26 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+
+ // As with Linux, if the UID, GID, or file size is changing, we have to
+ // clear permission bits. Note that when set, clearSGID causes
+ // permissions to be updated, but does not modify stat.Mask, as
+ // modification would cause an extra inotify flag to be set.
+ clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
+ stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+ stat.Mask&linux.STATX_SIZE != 0
+ if clearSGID {
+ if stat.Mask&linux.STATX_MODE != 0 {
+ stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
+ } else {
+ stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+ }
+ }
+
if !d.isSynthetic() {
if stat.Mask != 0 {
if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
+ Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
UID: stat.Mask&linux.STATX_UID != 0,
GID: stat.Mask&linux.STATX_GID != 0,
Size: stat.Mask&linux.STATX_SIZE != 0,
@@ -1140,7 +1156,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 {
+ if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
if stat.Mask&linux.STATX_UID != 0 {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 283b220bb..4f1ad0c88 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -266,6 +266,20 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
return 0, offset, err
}
}
+
+ // As with Linux, writing clears the setuid and setgid bits.
+ if n > 0 {
+ oldMode := atomic.LoadUint32(&d.mode)
+ // If setuid or setgid were set, update d.mode and propagate
+ // changes to the host.
+ if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode {
+ atomic.StoreUint32(&d.mode, newMode)
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
+ return 0, offset, err
+ }
+ }
+ }
+
return n, offset + n, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 84e37f793..46c500427 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -689,13 +689,9 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
return err
}
- creds := rp.Credentials()
+
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, rp.Credentials()),
}); err != nil {
if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr))
@@ -750,11 +746,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
creds := rp.Credentials()
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, creds),
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr))
@@ -963,14 +955,11 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
+
// Change the file's owner to the caller. We can't use upperFD.SetStat()
// because it will pick up creds from ctx.
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, creds),
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr))
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 58680bc80..454c20d4f 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -749,6 +749,27 @@ func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
)
}
+// newChildOwnerStat returns a Statx for configuring the UID, GID, and mode of
+// children.
+func (d *dentry) newChildOwnerStat(mode linux.FileMode, creds *auth.Credentials) linux.Statx {
+ stat := linux.Statx{
+ Mask: uint32(linux.STATX_UID | linux.STATX_GID),
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ }
+ // Set GID and possibly the SGID bit if the parent is an SGID directory.
+ d.copyMu.RLock()
+ defer d.copyMu.RUnlock()
+ if atomic.LoadUint32(&d.mode)&linux.ModeSetGID == linux.ModeSetGID {
+ stat.GID = atomic.LoadUint32(&d.gid)
+ if stat.Mode&linux.ModeDirectory == linux.ModeDirectory {
+ stat.Mode = uint16(mode) | linux.ModeSetGID
+ stat.Mask |= linux.STATX_MODE
+ }
+ }
+ return stat
+}
+
// fileDescription is embedded by overlay implementations of
// vfs.FileDescriptionImpl.
//
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 25c785fd4..d791c06db 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -205,6 +205,20 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
if err := wrappedFD.SetStat(ctx, opts); err != nil {
return err
}
+
+ // Changing owners may clear one or both of the setuid and setgid bits,
+ // so we may have to update opts before setting d.mode.
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{
+ Mask: linux.STATX_MODE,
+ })
+ if err != nil {
+ return err
+ }
+ opts.Stat.Mode = stat.Mode
+ opts.Stat.Mask |= linux.STATX_MODE
+ }
+
d.updateAfterSetStatLocked(&opts)
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
@@ -295,7 +309,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, err
}
defer wrappedFD.DecRef(ctx)
- return wrappedFD.PWrite(ctx, src, offset, opts)
+ n, err := wrappedFD.PWrite(ctx, src, offset, opts)
+ if err != nil {
+ return n, err
+ }
+ return fd.updateSetUserGroupIDs(ctx, wrappedFD, n)
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -307,7 +325,28 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
if err != nil {
return 0, err
}
- return wrappedFD.Write(ctx, src, opts)
+ n, err := wrappedFD.Write(ctx, src, opts)
+ if err != nil {
+ return n, err
+ }
+ return fd.updateSetUserGroupIDs(ctx, wrappedFD, n)
+}
+
+func (fd *regularFileFD) updateSetUserGroupIDs(ctx context.Context, wrappedFD *vfs.FileDescription, written int64) (int64, error) {
+ // Writing can clear the setuid and/or setgid bits. We only have to
+ // check this if something was written and one of those bits was set.
+ dentry := fd.dentry()
+ if written == 0 || atomic.LoadUint32(&dentry.mode)&(linux.S_ISUID|linux.S_ISGID) == 0 {
+ return written, nil
+ }
+ stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return written, err
+ }
+ dentry.copyMu.Lock()
+ defer dentry.copyMu.Unlock()
+ atomic.StoreUint32(&dentry.mode, uint32(stat.Mode))
+ return written, nil
}
// Seek implements vfs.FileDescriptionImpl.Seek.
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index 609ad3941..7aea3dcd8 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -51,14 +51,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
return err
case linux.PTRACE_GETFPREGS:
- _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{
+ s := target.Arch().FloatingPointData()
+ _, err := target.Arch().FloatingPointData().PtraceGetFPRegs(&usermem.IOReadWriter{
Ctx: t,
IO: t.MemoryManager(),
Addr: data,
Opts: usermem.IOOpts{
AddressSpaceActive: true,
},
- })
+ }, len(*s))
return err
case linux.PTRACE_SETREGS:
@@ -73,14 +74,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
return err
case linux.PTRACE_SETFPREGS:
- _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{
+ s := target.Arch().FloatingPointData()
+ _, err := s.PtraceSetFPRegs(&usermem.IOReadWriter{
Ctx: t,
IO: t.MemoryManager(),
Addr: data,
Opts: usermem.IOOpts{
AddressSpaceActive: true,
},
- })
+ }, len(*s))
return err
default:
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 4f9e781af..03a76eb9b 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
@@ -78,6 +79,7 @@ go_test(
"//pkg/ring0",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/time",
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 308696efe..d761bbdee 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -73,7 +73,7 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no.
+ ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
ring0.Halt()
ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
@@ -92,7 +92,7 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no.
+ ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
ring0.Halt()
ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
@@ -124,5 +124,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// Set the context pointer to the saved floating point state. This is
// where the guest data has been serialized, the kernel will restore
// from this new pointer value.
- context.Fpstate = uint64(uintptrValue(&c.floatingPointState[0]))
+ context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index c317f1e99..578852c3f 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -92,7 +92,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
context.Fpsimd64.Fpsr = fpsimd.Fpsr
context.Fpsimd64.Fpcr = fpsimd.Fpcr
context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -112,12 +112,12 @@ func (c *vCPU) KernelSyscall() {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(&c.floatingPointState[0])
+ ring0.SaveVRegs(c.floatingPointState.BytePointer())
}
ring0.Halt()
@@ -136,12 +136,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(&c.floatingPointState[0])
+ ring0.SaveVRegs(c.floatingPointState.BytePointer())
}
ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index 76fc594a0..e44e995a0 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -33,7 +33,7 @@ func TestSegments(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 6243b9a04..5bce16dde 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -25,13 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
-var dummyFPState = (*byte)(arch.NewFloatingPointData())
+var dummyFPState = fpu.NewState()
type testHarness interface {
Errorf(format string, args ...interface{})
@@ -159,7 +160,7 @@ func TestApplicationSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -173,7 +174,7 @@ func TestApplicationSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -190,7 +191,7 @@ func TestApplicationFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -205,7 +206,7 @@ func TestApplicationFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -223,7 +224,7 @@ func TestRegistersSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
continue // Retry.
@@ -246,7 +247,7 @@ func TestRegistersFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -272,7 +273,7 @@ func TestBounce(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err != platform.ErrContextInterrupt {
t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -287,7 +288,7 @@ func TestBounce(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err != platform.ErrContextInterrupt {
@@ -319,7 +320,7 @@ func TestBounceStress(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err != platform.ErrContextInterrupt {
t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -340,7 +341,7 @@ func TestInvalidate(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
continue // Retry.
@@ -355,7 +356,7 @@ func TestInvalidate(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
Flush: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -379,7 +380,7 @@ func TestEmptyAddressSpace(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -393,7 +394,7 @@ func TestEmptyAddressSpace(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -469,7 +470,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
a++
@@ -506,7 +507,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
a++
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 916903881..3af96c7e5 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
@@ -70,7 +71,7 @@ type vCPUArchState struct {
// floatingPointState is the floating point state buffer used in guest
// to host transitions. See usage in bluepill_amd64.go.
- floatingPointState arch.FloatingPointData
+ floatingPointState fpu.State
}
const (
@@ -151,7 +152,7 @@ func (c *vCPU) initArchState() error {
// This will be saved prior to leaving the guest, and we restore from
// this always. We cannot use the pointer in the context alone because
// we don't know how large the area there is in reality.
- c.floatingPointState = arch.NewFloatingPointData()
+ c.floatingPointState = fpu.NewState()
// Set the time offset to the host native time.
return c.setSystemTime()
@@ -307,12 +308,12 @@ func loadByte(ptr *byte) byte {
// emulate instructions like xsave and xrstor.
//
//go:nosplit
-func prefaultFloatingPointState(data arch.FloatingPointData) {
- size := len(data)
+func prefaultFloatingPointState(data *fpu.State) {
+ size := len(*data)
for i := 0; i < size; i += usermem.PageSize {
- loadByte(&(data)[i])
+ loadByte(&(*data)[i])
}
- loadByte(&(data)[size-1])
+ loadByte(&(*data)[size-1])
}
// SwitchToUser unpacks architectural-details.
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 3d715e570..2edc9d1b2 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -32,7 +33,7 @@ type vCPUArchState struct {
// floatingPointState is the floating point state buffer used in guest
// to host transitions. See usage in bluepill_arm64.go.
- floatingPointState arch.FloatingPointData
+ floatingPointState fpu.State
}
const (
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 059aa43d0..e7d5f3193 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -150,7 +151,7 @@ func (c *vCPU) initArchState() error {
c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
}
- c.floatingPointState = arch.NewFloatingPointData()
+ c.floatingPointState = fpu.NewState()
return c.setSystemTime()
}
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index fc43cc3c0..47efde6a2 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 6259350ec..01e73b019 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -62,9 +63,9 @@ func (t *thread) setRegs(regs *arch.Registers) error {
}
// getFPRegs gets the floating-point data via the GETREGSET ptrace unix.
-func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+func (t *thread) getFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error {
iovec := unix.Iovec{
- Base: (*byte)(&fpState[0]),
+ Base: fpState.BytePointer(),
Len: fpLen,
}
_, _, errno := unix.RawSyscall6(
@@ -81,9 +82,9 @@ func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsav
}
// setFPRegs sets the floating-point data via the SETREGSET ptrace unix.
-func (t *thread) setFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error {
iovec := unix.Iovec{
- Base: (*byte)(&fpState[0]),
+ Base: fpState.BytePointer(),
Len: fpLen,
}
_, _, errno := unix.RawSyscall6(
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 5bd526b73..efec93f73 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -75,17 +75,25 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o
// errors, we may consume the error and return only the partial read/write.
//
// Returns false if error is unknown.
-func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, op string) (bool, error) {
- switch err {
- case nil:
+func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr error, op string) (bool, error) {
+ if errOrig == nil {
// Typical successful syscall.
return true, nil
+ }
+
+ // Translate error, if possible, to consolidate errors from other packages
+ // into a smaller set of errors from syserror package.
+ translatedErr := errOrig
+ if errno, ok := syserror.TranslateError(errOrig); ok {
+ translatedErr = errno
+ }
+ switch translatedErr {
case io.EOF:
// EOF is always consumed. If this is a partial read/write
// (result != 0), the application will see that, otherwise
// they will see 0.
return true, nil
- case syserror.ErrExceedsFileSizeLimit:
+ case syserror.EFBIG:
t := kernel.TaskFromContext(ctx)
if t == nil {
panic("I/O error should only occur from a context associated with a Task")
@@ -98,7 +106,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
// Simultaneously send a SIGXFSZ per setrlimit(2).
t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
return true, syserror.EFBIG
- case syserror.ErrInterrupted:
+ case syserror.EINTR:
// The syscall was interrupted. Return nil if it completed
// partially, otherwise return the error code that the syscall
// needs (to indicate to the kernel what it should do).
@@ -110,10 +118,10 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
if !partialResult {
// Typical syscall error.
- return true, err
+ return true, errOrig
}
- switch err {
+ switch translatedErr {
case syserror.EINTR:
// Syscall interrupted, but completed a partial
// read/write. Like ErrWouldBlock, since we have a
@@ -143,7 +151,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
// For TCP sendfile connections, we may have a reset or timeout. But we
// should just return n as the result.
return true, nil
- case syserror.ErrWouldBlock:
+ case syserror.EWOULDBLOCK:
// Syscall would block, but completed a partial read/write.
// This case should only be returned by IssueIO for nonblocking
// files. Since we have a partial read/write, we consume
@@ -151,7 +159,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
return true, nil
}
- switch err.(type) {
+ switch errOrig.(type) {
case syserror.SyscallRestartErrno:
// Identical to the EINTR case.
return true, nil
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 97de17afe..56b621357 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -130,17 +130,15 @@ func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) {
// TranslateError translates errors to errnos, it will return false if
// the error was not registered.
func TranslateError(from error) (unix.Errno, bool) {
- err, ok := errorMap[from]
- if ok {
- return err, ok
+ if err, ok := errorMap[from]; ok {
+ return err, true
}
// Try to unwrap the error if we couldn't match an error
// exactly. This might mean that a package has its own
// error type.
for _, unwrap := range errorUnwrappers {
- err, ok := unwrap(from)
- if ok {
- return err, ok
+ if err, ok := unwrap(from); ok {
+ return err, true
}
}
return 0, false
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index fc622b246..fef065b05 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1287,6 +1287,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption
} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
}
+ case header.NDPNonceOption:
+ gotOpt, ok := opt.(header.NDPNonceOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
+ t.Errorf("nonce mismatch (-want +got):\n%s", diff)
+ }
default:
t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 554242f0c..3d1bccd15 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -26,29 +26,33 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// NDPOptionIdentifier is an NDP option type identifier.
-type NDPOptionIdentifier uint8
+// ndpOptionIdentifier is an NDP option type identifier.
+type ndpOptionIdentifier uint8
const (
- // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // ndpSourceLinkLayerAddressOptionType is the type of the Source Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
+ ndpSourceLinkLayerAddressOptionType ndpOptionIdentifier = 1
- // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // ndpTargetLinkLayerAddressOptionType is the type of the Target Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
+ ndpTargetLinkLayerAddressOptionType ndpOptionIdentifier = 2
- // NDPPrefixInformationType is the type of the Prefix Information
+ // ndpPrefixInformationType is the type of the Prefix Information
// option, as per RFC 4861 section 4.6.2.
- NDPPrefixInformationType NDPOptionIdentifier = 3
+ ndpPrefixInformationType ndpOptionIdentifier = 3
- // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // ndpNonceOptionType is the type of the Nonce option, as per
+ // RFC 3971 section 5.3.2.
+ ndpNonceOptionType ndpOptionIdentifier = 14
+
+ // ndpRecursiveDNSServerOptionType is the type of the Recursive DNS
// Server option, as per RFC 8106 section 5.1.
- NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+ ndpRecursiveDNSServerOptionType ndpOptionIdentifier = 25
- // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // ndpDNSSearchListOptionType is the type of the DNS Search List option,
// as per RFC 8106 section 5.2.
- NDPDNSSearchListOptionType = 31
+ ndpDNSSearchListOptionType ndpOptionIdentifier = 31
)
const (
@@ -198,7 +202,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
// bytes for the whole option.
return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
}
- kind := NDPOptionIdentifier(temp)
+ kind := ndpOptionIdentifier(temp)
// Get the Length field.
length, err := i.opts.ReadByte()
@@ -225,13 +229,16 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
}
switch kind {
- case NDPSourceLinkLayerAddressOptionType:
+ case ndpSourceLinkLayerAddressOptionType:
return NDPSourceLinkLayerAddressOption(body), false, nil
- case NDPTargetLinkLayerAddressOptionType:
+ case ndpTargetLinkLayerAddressOptionType:
return NDPTargetLinkLayerAddressOption(body), false, nil
- case NDPPrefixInformationType:
+ case ndpNonceOptionType:
+ return NDPNonceOption(body), false, nil
+
+ case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
// section 4.6.2.
@@ -241,7 +248,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
return NDPPrefixInformation(body), false, nil
- case NDPRecursiveDNSServerOptionType:
+ case ndpRecursiveDNSServerOptionType:
opt := NDPRecursiveDNSServer(body)
if err := opt.checkAddresses(); err != nil {
return nil, true, err
@@ -249,7 +256,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
return opt, false, nil
- case NDPDNSSearchListOptionType:
+ case ndpDNSSearchListOptionType:
opt := NDPDNSSearchList(body)
if err := opt.checkDomainNames(); err != nil {
return nil, true, err
@@ -316,7 +323,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
continue
}
- b[0] = byte(o.Type())
+ b[0] = byte(o.kind())
// We know this safe because paddedLength would have returned
// 0 if o had an invalid length (> 255 * lengthByteUnits).
@@ -341,11 +348,11 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
type NDPOption interface {
fmt.Stringer
- // Type returns the type of the receiver.
- Type() NDPOptionIdentifier
+ // kind returns the type of the receiver.
+ kind() ndpOptionIdentifier
- // Length returns the length of the body of the receiver, in bytes.
- Length() int
+ // length returns the length of the body of the receiver, in bytes.
+ length() int
// serializeInto serializes the receiver into the provided byte
// buffer.
@@ -365,7 +372,7 @@ type NDPOption interface {
// paddedLength returns the length of o, in bytes, with any padding bytes, if
// required.
func paddedLength(o NDPOption) int {
- l := o.Length()
+ l := o.length()
if l == 0 {
return 0
@@ -416,6 +423,37 @@ func (b NDPOptionsSerializer) Length() int {
return l
}
+// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
+type NDPNonceOption []byte
+
+// kind implements NDPOption.
+func (o NDPNonceOption) kind() ndpOptionIdentifier {
+ return ndpNonceOptionType
+}
+
+// length implements NDPOption.
+func (o NDPNonceOption) length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.
+func (o NDPNonceOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPNonceOption) String() string {
+ return fmt.Sprintf("%T(%x)", o, []byte(o))
+}
+
+// Nonce returns the nonce value this option holds.
+func (o NDPNonceOption) Nonce() []byte {
+ return []byte(o)
+}
+
// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
// as defined by RFC 4861 section 4.6.1.
//
@@ -423,22 +461,22 @@ func (b NDPOptionsSerializer) Length() int {
// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
-// Type implements NDPOption.Type.
-func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
- return NDPSourceLinkLayerAddressOptionType
+// kind implements NDPOption.
+func (o NDPSourceLinkLayerAddressOption) kind() ndpOptionIdentifier {
+ return ndpSourceLinkLayerAddressOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPSourceLinkLayerAddressOption) Length() int {
+// length implements NDPOption.
+func (o NDPSourceLinkLayerAddressOption) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPSourceLinkLayerAddressOption) String() string {
return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
}
@@ -463,22 +501,22 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
-// Type implements NDPOption.Type.
-func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
- return NDPTargetLinkLayerAddressOptionType
+// kind implements NDPOption.
+func (o NDPTargetLinkLayerAddressOption) kind() ndpOptionIdentifier {
+ return ndpTargetLinkLayerAddressOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPTargetLinkLayerAddressOption) Length() int {
+// length implements NDPOption.
+func (o NDPTargetLinkLayerAddressOption) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPTargetLinkLayerAddressOption) String() string {
return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
}
@@ -503,17 +541,17 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
// ndpPrefixInformationLength bytes.
type NDPPrefixInformation []byte
-// Type implements NDPOption.Type.
-func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
- return NDPPrefixInformationType
+// kind implements NDPOption.
+func (o NDPPrefixInformation) kind() ndpOptionIdentifier {
+ return ndpPrefixInformationType
}
-// Length implements NDPOption.Length.
-func (o NDPPrefixInformation) Length() int {
+// length implements NDPOption.
+func (o NDPPrefixInformation) length() int {
return ndpPrefixInformationLength
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPPrefixInformation) serializeInto(b []byte) int {
used := copy(b, o)
@@ -529,7 +567,7 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPPrefixInformation) String() string {
return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
o,
@@ -627,17 +665,17 @@ type NDPRecursiveDNSServer []byte
// Type returns the type of an NDP Recursive DNS Server option.
//
-// Type implements NDPOption.Type.
-func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
- return NDPRecursiveDNSServerOptionType
+// kind implements NDPOption.
+func (NDPRecursiveDNSServer) kind() ndpOptionIdentifier {
+ return ndpRecursiveDNSServerOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPRecursiveDNSServer) Length() int {
+// length implements NDPOption.
+func (o NDPRecursiveDNSServer) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
used := copy(b, o)
@@ -649,7 +687,7 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPRecursiveDNSServer) String() string {
lt := o.Lifetime()
addrs, err := o.Addresses()
@@ -722,17 +760,17 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
// RFC 8106 section 5.2.
type NDPDNSSearchList []byte
-// Type implements NDPOption.Type.
-func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
- return NDPDNSSearchListOptionType
+// kind implements NDPOption.
+func (o NDPDNSSearchList) kind() ndpOptionIdentifier {
+ return ndpDNSSearchListOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPDNSSearchList) Length() int {
+// length implements NDPOption.
+func (o NDPDNSSearchList) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPDNSSearchList) serializeInto(b []byte) int {
used := copy(b, o)
@@ -744,7 +782,7 @@ func (o NDPDNSSearchList) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPDNSSearchList) String() string {
lt := o.Lifetime()
domainNames, err := o.DomainNames()
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index dc4591253..d0a1a2492 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -16,6 +16,7 @@ package header
import (
"bytes"
+ "encoding/binary"
"errors"
"fmt"
"io"
@@ -192,90 +193,6 @@ func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
}
-// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
-// NDPSourceLinkLayerAddressOption.
-func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expectedBuf []byte
- addr tcpip.LinkAddress
- }{
- {
- "Ethernet",
- make([]byte, 8),
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
- },
- {
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
- },
- {
- "Empty",
- nil,
- nil,
- "",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- serializer := NDPOptionsSerializer{
- NDPSourceLinkLayerAddressOption(test.addr),
- }
- if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
- }
- opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- if len(test.expectedBuf) > 0 {
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
- }
- sll := next.(NDPSourceLinkLayerAddressOption)
- if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
- }
- }
-
- // Iterator should not return anything else.
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
// Ethernet address from an NDPTargetLinkLayerAddressOption.
func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
@@ -311,32 +228,309 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
}
-// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
-// NDPTargetLinkLayerAddressOption.
-func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
+func TestOpts(t *testing.T) {
+ const optionHeaderLen = 2
+
+ checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpNonceOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType)
+ }
+ nonce, ok := opt.(NDPNonceOption)
+ if !ok {
+ t.Fatalf("got nonce = %T, want = NDPNonceOption", opt)
+ }
+ if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" {
+ t.Errorf("nonce mismatch (-want +got):\n%s", diff)
+ }
+ }
+ }
+
+ checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
+ }
+ tll, ok := opt.(NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt)
+ }
+ if got, want := tll.EthernetAddress(), expectedAddr; got != want {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+ }
+
+ checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
+ }
+ sll, ok := opt.(NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt)
+ }
+ if got, want := sll.EthernetAddress(), expectedAddr; got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+ }
+
+ const validLifetimeSeconds = 16909060
+ const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18")
+
+ expectedRDNSSBytes := [...]byte{
+ // Type, Length
+ 25, 3,
+
+ // Reserved
+ 0, 0,
+
+ // Lifetime
+ 1, 2, 4, 8,
+
+ // Address
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds)
+ if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize {
+ t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
+ }
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ rdnssBytes := expectedRDNSSBytes
+	rdnssBytes[2] = 1
+	rdnssBytes[3] = 2
+
+ const searchListPaddingBytes = 3
+ const domainName = "abc.abcd.e"
+ expectedSearchListBytes := [...]byte{
+ // Type, Length
+ 31, 3,
+
+ // Reserved
+ 0, 0,
+
+ // Lifetime
+ 1, 0, 0, 0,
+
+ // Domain names
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds)
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ searchListBytes := expectedSearchListBytes
+ searchListBytes[2] = 1
+ searchListBytes[3] = 2
+
+ const prefixLength = 43
+ const onLinkFlag = false
+ const slaacFlag = true
+ const preferredLifetimeSeconds = 84281096
+ const onLinkFlagBit = 7
+ const slaacFlagBit = 6
+ boolToByte := func(v bool) byte {
+ if v {
+ return 1
+ }
+ return 0
+ }
+ flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit
+ expectedPrefixInformationBytes := [...]byte{
+ // Type, Length
+ 3, 4,
+
+ prefixLength, flags,
+
+ // Valid Lifetime
+ 1, 2, 3, 4,
+
+ // Preferred Lifetime
+ 5, 6, 7, 8,
+
+ // Reserved2
+ 0, 0, 0, 0,
+
+ // Address
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+ binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds)
+ binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds)
+ if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize {
+ t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
+ }
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ prefixInformationBytes := expectedPrefixInformationBytes
+ prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1
+ binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1)
tests := []struct {
name string
buf []byte
+ opt NDPOption
expectedBuf []byte
- addr tcpip.LinkAddress
+ check func(*testing.T, NDPOption)
}{
{
- "Ethernet",
- make([]byte, 8),
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
+ name: "Nonce",
+ buf: make([]byte, 8),
+ opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}),
+ expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6},
+ check: checkNonce([]byte{1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Nonce with padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}),
+ expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0},
+ check: checkNonce([]byte{1, 2, 3, 4, 5, 0}),
+ },
+
+ {
+ name: "TLL Ethernet",
+ buf: make([]byte, 8),
+ opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
+ expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ check: checkTLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "TLL Padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
+ expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ check: checkTLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "TLL Empty",
+ buf: nil,
+ opt: NDPTargetLinkLayerAddressOption(""),
+ expectedBuf: nil,
},
+
{
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
+ name: "SLL Ethernet",
+ buf: make([]byte, 8),
+ opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
+ expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ check: checkSLL("\x01\x02\x03\x04\x05\x06"),
},
{
- "Empty",
- nil,
- nil,
- "",
+ name: "SLL Padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
+ expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ check: checkSLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "SLL Empty",
+ buf: nil,
+ opt: NDPSourceLinkLayerAddressOption(""),
+ expectedBuf: nil,
+ },
+
+ {
+ name: "RDNSS",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ // NDPRecursiveDNSServer holds the option after the header bytes.
+ opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]),
+ expectedBuf: expectedRDNSSBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpRecursiveDNSServerOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
+ }
+ rdnss, ok := opt.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt)
+ }
+ if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got Lifetime() = %s, want = %s", got, want)
+ }
+ if addrs, err := rdnss.Addresses(); err != nil {
+ t.Errorf("Addresses(): %s", err)
+ } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+ },
+ },
+
+ {
+ name: "Search list",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]),
+ expectedBuf: expectedSearchListBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpDNSSearchListOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType)
+ }
+
+ dnssl, ok := opt.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt)
+ }
+				if got, want := dnssl.length(), len(expectedSearchListBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got Lifetime() = %s, want = %s", got, want)
+ }
+
+ if domainNames, err := dnssl.DomainNames(); err != nil {
+ t.Errorf("DomainNames(): %s", err)
+ } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+ },
+ },
+
+ {
+ name: "Prefix Information",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ // NDPPrefixInformation holds the option after the header bytes.
+ opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]),
+ expectedBuf: expectedPrefixInformationBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpPrefixInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType)
+ }
+
+ pi, ok := opt.(NDPPrefixInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt)
+ }
+
+ if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got := pi.PrefixLength(); got != prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength)
+ }
+ if got := pi.OnLinkFlag(); got != onLinkFlag {
+ t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag)
+ }
+ if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag {
+ t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag)
+ }
+ if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got ValidLifetime() = %s, want = %s", got, want)
+ }
+ if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got PreferredLifetime() = %s, want = %s", got, want)
+ }
+ if got := pi.Prefix(); got != address {
+ t.Errorf("got Prefix() = %s, want = %s", got, address)
+ }
+ },
},
}
@@ -344,230 +538,47 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := NDPOptions(test.buf)
serializer := NDPOptionsSerializer{
- NDPTargetLinkLayerAddressOption(test.addr),
+ test.opt,
}
if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
+ t.Fatalf("got Length() = %d, want = %d", got, want)
}
opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" {
+ t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff)
}
it, err := opts.Iter(true)
if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err)
}
if len(test.expectedBuf) > 0 {
next, done, err := it.Next()
if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err)
}
if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
- }
- tll := next.(NDPTargetLinkLayerAddressOption)
- if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ t.Fatal("got Next() = (_, true, _), want = (_, false, _)")
}
+ test.check(t, next)
}
// Iterator should not return anything else.
next, done, err := it.Next()
if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
}
if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
}
if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
}
})
}
}
-// TestNDPPrefixInformationOption tests the field getters and serialization of a
-// NDPPrefixInformation.
-func TestNDPPrefixInformationOption(t *testing.T) {
- b := []byte{
- 43, 127,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 5, 5, 5, 5,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
-
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPPrefixInformation(b),
- }
- opts.Serialize(serializer)
- expectedBuf := []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
- if !bytes.Equal(targetBuf, expectedBuf) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
- }
-
- pi := next.(NDPPrefixInformation)
-
- if got := pi.Type(); got != 3 {
- t.Errorf("got Type = %d, want = 3", got)
- }
-
- if got := pi.Length(); got != 30 {
- t.Errorf("got Length = %d, want = 30", got)
- }
-
- if got := pi.PrefixLength(); got != 43 {
- t.Errorf("got PrefixLength = %d, want = 43", got)
- }
-
- if pi.OnLinkFlag() {
- t.Error("got OnLinkFlag = true, want = false")
- }
-
- if !pi.AutonomousAddressConfigurationFlag() {
- t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
- }
-
- if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
- t.Errorf("got ValidLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
- t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
- t.Errorf("got Prefix = %s, want = %s", got, want)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
-func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
- b := []byte{
- 9, 8,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- expected := []byte{
- 25, 3, 0, 0,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPRecursiveDNSServer(b),
- }
- if got, want := opts.Serialize(serializer), len(expected); got != want {
- t.Errorf("got Serialize = %d, want = %d", got, want)
- }
- if !bytes.Equal(targetBuf, expected) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
- }
-
- opt, ok := next.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
- }
- if got := opt.Type(); got != 25 {
- t.Errorf("got Type = %d, want = 31", got)
- }
- if got := opt.Length(); got != 22 {
- t.Errorf("got Length = %d, want = 22", got)
- }
- if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
- t.Errorf("got Lifetime = %s, want = %s", got, want)
- }
- want := []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- }
- addrs, err := opt.Addresses()
- if err != nil {
- t.Errorf("opt.Addresses() = %s", err)
- }
- if diff := cmp.Diff(addrs, want); diff != "" {
- t.Errorf("mismatched addresses (-want +got):\n%s", diff)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
func TestNDPRecursiveDNSServerOption(t *testing.T) {
tests := []struct {
name string
@@ -635,8 +646,8 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ if got := next.kind(); got != ndpRecursiveDNSServerOptionType {
+		t.Fatalf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
}
opt, ok := next.(NDPRecursiveDNSServer)
@@ -1060,86 +1071,6 @@ func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
}
}
-func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
- b := []byte{
- 9, 8,
- 1, 0, 0, 0,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- }
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- expected := []byte{
- 31, 3, 0, 0,
- 1, 0, 0, 0,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- 0, 0, 0, 0,
- }
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPDNSSearchList(b),
- }
- if got, want := opts.Serialize(serializer), len(expected); got != want {
- t.Errorf("got Serialize = %d, want = %d", got, want)
- }
- if !bytes.Equal(targetBuf, expected) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPDNSSearchListOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
- }
-
- opt, ok := next.(NDPDNSSearchList)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
- }
- if got := opt.Type(); got != 31 {
- t.Errorf("got Type = %d, want = 31", got)
- }
- if got := opt.Length(); got != 22 {
- t.Errorf("got Length = %d, want = 22", got)
- }
- if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
- t.Errorf("got Lifetime = %s, want = %s", got, want)
- }
- domainNames, err := opt.DomainNames()
- if err != nil {
- t.Errorf("opt.DomainNames() = %s", err)
- }
- if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" {
- t.Errorf("domain names mismatch (-want +got):\n%s", diff)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
// the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
@@ -1472,8 +1403,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType {
+		t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
}
// Test the next (Target Link-Layer) option.
@@ -1487,8 +1418,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType {
+		t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
}
// Test the next (Prefix Information) option.
@@ -1503,8 +1434,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ if got := next.kind(); got != ndpPrefixInformationType {
+		t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType)
}
// Iterator should not return anything else.
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
index 6fe9a336b..55ab1d7cf 100644
--- a/pkg/tcpip/header/ndpoptionidentifier_string.go
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT.
package header
@@ -22,29 +22,37 @@ func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
- _ = x[NDPSourceLinkLayerAddressOptionType-1]
- _ = x[NDPTargetLinkLayerAddressOptionType-2]
- _ = x[NDPPrefixInformationType-3]
- _ = x[NDPRecursiveDNSServerOptionType-25]
+ _ = x[ndpSourceLinkLayerAddressOptionType-1]
+ _ = x[ndpTargetLinkLayerAddressOptionType-2]
+ _ = x[ndpPrefixInformationType-3]
+ _ = x[ndpNonceOptionType-14]
+ _ = x[ndpRecursiveDNSServerOptionType-25]
+ _ = x[ndpDNSSearchListOptionType-31]
}
const (
- _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
- _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+ _ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType"
+ _ndpOptionIdentifier_name_1 = "ndpNonceOptionType"
+ _ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType"
+ _ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType"
)
var (
- _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+ _ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
)
-func (i NDPOptionIdentifier) String() string {
+func (i ndpOptionIdentifier) String() string {
switch {
case 1 <= i && i <= 3:
i -= 1
- return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]]
+ case i == 14:
+ return _ndpOptionIdentifier_name_1
case i == 25:
- return _NDPOptionIdentifier_name_1
+ return _ndpOptionIdentifier_name_2
+ case i == 31:
+ return _ndpOptionIdentifier_name_3
default:
- return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ae0461a6d..7ae38d684 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -38,6 +38,7 @@ const (
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
+var _ ip.DADProtocol = (*endpoint)(nil)
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
@@ -82,7 +83,8 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error {
+// SendDADMessage implements ip.DADProtocol.
+func (e *endpoint) SendDADMessage(addr tcpip.Address, _ []byte) tcpip.Error {
return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress)
}
@@ -284,9 +286,12 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
e.mu.Lock()
e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{
- Clock: p.stack.Clock(),
- Protocol: e,
- NICID: nic.ID(),
+ Clock: p.stack.Clock(),
+ SecureRNG: p.stack.SecureRNG(),
+ // ARP does not support sending nonce values.
+ NonceSize: 0,
+ Protocol: e,
+ NICID: nic.ID(),
})
e.mu.Unlock()
@@ -305,8 +310,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
- nicID := e.nic.ID()
-
stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
@@ -314,9 +317,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
- addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
- if !ok {
- return &tcpip.ErrUnknownNICID{}
+ addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
}
if len(addr.Address) == 0 {
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index 0053646ee..eed49f5d2 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -16,14 +16,27 @@
package ip
import (
+ "bytes"
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+type extendRequest int
+
+const (
+ notRequested extendRequest = iota
+ requested
+ extended
+)
+
type dadState struct {
+ nonce []byte
+ extendRequest extendRequest
+
done *bool
timer tcpip.Timer
@@ -33,14 +46,17 @@ type dadState struct {
// DADProtocol is a protocol whose core state machine can be represented by DAD.
type DADProtocol interface {
// SendDADMessage attempts to send a DAD probe message.
- SendDADMessage(tcpip.Address) tcpip.Error
+ SendDADMessage(tcpip.Address, []byte) tcpip.Error
}
// DADOptions holds options for DAD.
type DADOptions struct {
- Clock tcpip.Clock
- Protocol DADProtocol
- NICID tcpip.NICID
+ Clock tcpip.Clock
+ SecureRNG io.Reader
+ NonceSize uint8
+ ExtendDADTransmits uint8
+ Protocol DADProtocol
+ NICID tcpip.NICID
}
// DAD performs duplicate address detection for addresses.
@@ -63,6 +79,10 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic("attempted to initialize DAD state twice")
}
+ if opts.NonceSize != 0 && opts.ExtendDADTransmits == 0 {
+ panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
+ }
+
*d = DAD{
opts: opts,
configs: configs,
@@ -96,10 +116,55 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
s = dadState{
done: &done,
timer: d.opts.Clock.AfterFunc(0, func() {
- var err tcpip.Error
dadDone := remaining == 0
+
+ nonce, earlyReturn := func() ([]byte, bool) {
+ d.protocolMU.Lock()
+ defer d.protocolMU.Unlock()
+
+ if done {
+ return nil, true
+ }
+
+ s, ok := d.addresses[addr]
+ if !ok {
+ panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID))
+ }
+
+ // As per RFC 7527 section 4
+ //
+ // If any probe is looped back within RetransTimer milliseconds
+ // after having sent DupAddrDetectTransmits NS(DAD) messages, the
+ // interface continues with another MAX_MULTICAST_SOLICIT number of
+ // NS(DAD) messages transmitted RetransTimer milliseconds apart.
+ if dadDone && s.extendRequest == requested {
+ dadDone = false
+ remaining = d.opts.ExtendDADTransmits
+ s.extendRequest = extended
+ }
+
+ if !dadDone && d.opts.NonceSize != 0 {
+ if s.nonce == nil {
+ s.nonce = make([]byte, d.opts.NonceSize)
+ }
+
+ if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil {
+ panic(fmt.Sprintf("SecureRNG.Read(...): %s", err))
+ } else if n != len(s.nonce) {
+ panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n))
+ }
+ }
+
+ d.addresses[addr] = s
+ return s.nonce, false
+ }()
+ if earlyReturn {
+ return
+ }
+
+ var err tcpip.Error
if !dadDone {
- err = d.opts.Protocol.SendDADMessage(addr)
+ err = d.opts.Protocol.SendDADMessage(addr, nonce)
}
d.protocolMU.Lock()
@@ -142,6 +207,68 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
return ret
}
+// ExtendIfNonceEqualLockedDisposition enumerates the possible results from
+// ExtendIfNonceEqualLocked.
+type ExtendIfNonceEqualLockedDisposition int
+
+const (
+ // Extended indicates that the DAD process was extended.
+ Extended ExtendIfNonceEqualLockedDisposition = iota
+
+ // AlreadyExtended indicates that the DAD process was already extended.
+ AlreadyExtended
+
+ // NoDADStateFound indicates that DAD state was not found for the address.
+ NoDADStateFound
+
+ // NonceDisabled indicates that nonce values are not sent with DAD messages.
+ NonceDisabled
+
+ // NonceNotEqual indicates that the nonce value passed and the nonce in the
+// last sent DAD message are not equal.
+ NonceNotEqual
+)
+
+// ExtendIfNonceEqualLocked extends the DAD process if the provided nonce is the
+// same as the nonce sent in the last DAD message.
+//
+// Precondition: d.protocolMU must be locked.
+func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition {
+ s, ok := d.addresses[addr]
+ if !ok {
+ return NoDADStateFound
+ }
+
+ if d.opts.NonceSize == 0 {
+ return NonceDisabled
+ }
+
+ if s.extendRequest != notRequested {
+ return AlreadyExtended
+ }
+
+ // As per RFC 7527 section 4
+ //
+ // If any probe is looped back within RetransTimer milliseconds after having
+ // sent DupAddrDetectTransmits NS(DAD) messages, the interface continues
+ // with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted
+ // RetransTimer milliseconds apart.
+ //
+ // If a DAD message has already been sent and the nonce value we observed is
+ // the same as the nonce value we last sent, then we assume our probe was
+ // looped back and request an extension to the DAD process.
+ //
+ // Note, the first DAD message is sent asynchronously so we need to make sure
+ // that we sent a DAD message by checking if we have a nonce value set.
+ if s.nonce != nil && bytes.Equal(s.nonce, nonce) {
+ s.extendRequest = requested
+ d.addresses[addr] = s
+ return Extended
+ }
+
+ return NonceNotEqual
+}
+
// StopLocked stops a currently running DAD process.
//
// Precondition: d.protocolMU must be locked.
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index e00aa4678..a22b712c6 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"testing"
"time"
@@ -32,8 +33,8 @@ type mockDADProtocol struct {
mu struct {
sync.Mutex
- dad ip.DAD
- sendCount map[tcpip.Address]int
+ dad ip.DAD
+ sentNonces map[tcpip.Address][][]byte
}
}
@@ -48,26 +49,30 @@ func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.
}
func (m *mockDADProtocol) initLocked() {
- m.mu.sendCount = make(map[tcpip.Address]int)
+ m.mu.sentNonces = make(map[tcpip.Address][][]byte)
}
-func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error {
+func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
m.mu.Lock()
defer m.mu.Unlock()
- m.mu.sendCount[addr]++
+ m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce)
return nil
}
func (m *mockDADProtocol) check(addrs []tcpip.Address) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- sendCount := make(map[tcpip.Address]int)
+ sentNonces := make(map[tcpip.Address][][]byte)
for _, a := range addrs {
- sendCount[a]++
+ sentNonces[a] = append(sentNonces[a], nil)
}
- diff := cmp.Diff(sendCount, m.mu.sendCount)
+ return m.checkWithNonce(sentNonces)
+}
+
+func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces)
m.initLocked()
return diff
}
@@ -84,6 +89,12 @@ func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
m.mu.dad.StopLocked(addr, reason)
}
+func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce)
+}
+
func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -277,3 +288,94 @@ func TestDADStop(t *testing.T) {
default:
}
}
+
+func TestNonce(t *testing.T) {
+ const (
+ nonceSize = 2
+
+ extendRequestAttempts = 2
+
+ dupAddrDetectTransmits = 2
+ extendTransmits = 5
+ )
+
+ var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte
+ for i := range secureRNGBytes {
+ secureRNGBytes[i] = byte(i)
+ }
+
+ tests := []struct {
+ name string
+ mockedReceivedNonce []byte
+ expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition
+ expectedTransmits int
+ }{
+ {
+ name: "not matching",
+ mockedReceivedNonce: []byte{0, 0},
+ expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual},
+ expectedTransmits: dupAddrDetectTransmits,
+ },
+ {
+ name: "matching nonce",
+ mockedReceivedNonce: secureRNGBytes[:nonceSize],
+ expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended},
+ expectedTransmits: dupAddrDetectTransmits + extendTransmits,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var dad mockDADProtocol
+ clock := faketime.NewManualClock()
+ dadConfigs := stack.DADConfigurations{
+ DupAddrDetectTransmits: dupAddrDetectTransmits,
+ RetransmitTimer: time.Second,
+ }
+
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
+ dad.init(t, dadConfigs, ip.DADOptions{
+ Clock: clock,
+ SecureRNG: &secureRNG,
+ NonceSize: nonceSize,
+ ExtendDADTransmits: extendTransmits,
+ })
+
+ ch := make(chan dadResult, 1)
+ if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
+ }
+
+ clock.Advance(0)
+ for i, want := range test.expectedResults {
+ if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
+ t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
+ }
+ }
+
+ for i := 0; i < test.expectedTransmits; i++ {
+ if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{
+ addr1: {
+ secureRNGBytes[nonceSize*i:][:nonceSize],
+ },
+ }); diff != "" {
+ t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff)
+ }
+
+ clock.Advance(dadConfigs.RetransmitTimer)
+ }
+
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
+ t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more updates.
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index aee1652fa..a4edc69c7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 8a2140ebe..a1660e9a3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if !e.protocol.options.AllowExternalLoopbackTraffic {
+ if header.IsV4LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV4LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
- h := header.IPv4(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
- // There has been some confusion regarding verifying checksums. We need
- // just look for negative 0 (0xffff) as the checksum, as it's not possible to
- // get positive 0 (0) for the checksum. Some bad implementations could get it
- // when doing entry replacement in the early days of the Internet,
- // however the lore that one needs to check for both persists.
- //
- // RFC 1624 section 1 describes the source of this confusion as:
- // [the partial recalculation method described in RFC 1071] computes a
- // result for certain cases that differs from the one obtained from
- // scratch (one's complement of one's complement sum of the original
- // fields).
- //
- // However RFC 1624 section 5 clarifies that if using the verification method
- // "recommended by RFC 1071, it does not matter if an intermediate system
- // generated a -0 instead of +0".
- //
- // RFC1071 page 1 specifies the verification method as:
- // (3) To check a checksum, the 1's complement sum is computed over the
- // same set of octets, including the checksum field. If the result
- // is all 1 bits (-0 in 1's complement arithmetic), the check
- // succeeds.
- if h.CalculateChecksum() != 0xffff {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1114,13 +1097,46 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns the parsed IP header and true if the IP header was successfully
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv4(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
+ }
+
+ // There has been some confusion regarding verifying checksums. We need
+ // only look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
+ return nil, false
}
if hasTransportHdr {
@@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN
type Options struct {
// IGMP holds options for IGMP.
IGMP IGMPOptions
+
+ // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be accepted.
+ AllowExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv4 network protocol.
@@ -1599,9 +1619,8 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
// TODO(https://gvisor.dev/issue/4586): This will need tweaking when we start
// really forwarding packets as we may need to get two addresses, for rx and
// tx interfaces. We will also have to take usage into account.
- prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
- localAddress := prefixedAddress.Address
- if !ok {
+ localAddress := e.MainAddress().Address
+ if len(localAddress) == 0 {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6344a3e09..2afa856dc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -369,6 +369,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
+ var it header.NDPOptionIterator
+ {
+ var err error
+ it, err = ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the
+ // packet.
+ received.invalid.Increment()
+ return
+ }
+ }
+
if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
@@ -382,6 +394,22 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if srcAddr == header.IPv6Any {
+ var nonce []byte
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ received.invalid.Increment()
+ return
+ }
+ if done {
+ break
+ }
+ if n, ok := opt.(header.NDPNonceOption); ok {
+ nonce = n.Nonce()
+ break
+ }
+ }
+
// Since this is a DAD message we know the sender does not actually hold
// the target address so there is no "holder".
var holderLinkAddress tcpip.LinkAddress
@@ -397,7 +425,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress, nonce); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
@@ -418,21 +446,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- var sourceLinkAddr tcpip.LinkAddress
- {
- it, err := ns.Options().Iter(false /* check */)
- if err != nil {
- // Options are not valid as per the wire format, silently drop the
- // packet.
- received.invalid.Increment()
- return
- }
-
- sourceLinkAddr, ok = getSourceLinkAddr(it)
- if !ok {
- received.invalid.Increment()
- return
- }
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.invalid.Increment()
+ return
}
// As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
@@ -586,6 +603,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
e.dad.mu.Unlock()
if e.hasTentativeAddr(targetAddr) {
+ // We only send a nonce value in DAD messages to check for looped-back
+ // messages so we use the empty nonce value here.
+ var nonce []byte
+
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
@@ -602,7 +623,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr, nonce); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
return
default:
@@ -899,13 +920,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
+ // Find an address that we can use as our source address.
addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
if addressEndpoint == nil {
return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addressEndpoint.AddressWithPrefix().Address
- } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 {
+ addressEndpoint.DecRef()
+ } else if !e.checkLocalAddress(localAddr) {
+ // The provided local address is not assigned to us.
return &tcpip.ErrBadLocalAddress{}
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d4e63710c..47d713f88 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 46b6cc41a..83e98bab9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress, nonce []byte) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -361,27 +361,48 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
return &tcpip.ErrInvalidEndpointState{}
}
- // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
- // attempt will be made to generate a new address for it.
- if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
- return err
- }
+ switch result := e.mu.ndp.dad.ExtendIfNonceEqualLocked(addr, nonce); result {
+ case ip.Extended:
+ // The nonce we got back was the same we sent so we know the message
+ // indicating a duplicate address was likely ours so do not consider
+ // the address duplicate here.
+ return nil
+ case ip.AlreadyExtended:
+ // See Extended.
+ //
+ // Our DAD message was looped back already.
+ return nil
+ case ip.NoDADStateFound:
+ panic(fmt.Sprintf("expected DAD state for tentative address %s", addr))
+ case ip.NonceDisabled:
+ // If nonce is disabled then we have no way to know if the packet was
+ // looped-back so we have to assume it indicates a duplicate address.
+ fallthrough
+ case ip.NonceNotEqual:
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
+ return err
+ }
- prefix := addressEndpoint.Subnet()
+ prefix := addressEndpoint.Subnet()
- switch t := addressEndpoint.ConfigType(); t {
- case stack.AddressConfigStatic:
- case stack.AddressConfigSlaac:
- e.mu.ndp.regenerateSLAACAddr(prefix)
- case stack.AddressConfigSlaacTemp:
- // Do not reset the generation attempts counter for the prefix as the
- // temporary address is being regenerated in response to a DAD conflict.
- e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
default:
- panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ panic(fmt.Sprintf("unhandled result = %d", result))
}
-
- return nil
}
// transitionForwarding transitions the endpoint's forwarding status to
@@ -863,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
-
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -904,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if !e.protocol.options.AllowExternalLoopbackTraffic {
+ if header.IsV6LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV6LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -932,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
-
- h := header.IPv6(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.MalformedPacketsReceived.Increment()
- return
- }
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1797,16 +1826,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
dispatcher: dispatcher,
protocol: p,
}
+
+ // NDP options must be 8 octet aligned and the first 2 bytes are used for
+ // the type and length fields leaving 6 octets as the minimum size for a
+ // nonce option without padding.
+ const nonceSize = 6
+
+ // As per RFC 7527 section 4.1,
+ //
+ // If any probe is looped back within RetransTimer milliseconds after
+ // having sent DupAddrDetectTransmits NS(DAD) messages, the interface
+ // continues with another MAX_MULTICAST_SOLICIT number of NS(DAD)
+ // messages transmitted RetransTimer milliseconds apart.
+ //
+ // Value taken from RFC 4861 section 10.
+ const maxMulticastSolicit = 3
+ dadOptions := ip.DADOptions{
+ Clock: p.stack.Clock(),
+ SecureRNG: p.stack.SecureRNG(),
+ NonceSize: nonceSize,
+ ExtendDADTransmits: maxMulticastSolicit,
+ Protocol: &e.mu.ndp,
+ NICID: nic.ID(),
+ }
+
e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp.init(e)
+ e.mu.ndp.init(e, dadOptions)
e.mu.mld.init(e)
e.dad.mu.Lock()
- e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{
- Clock: p.stack.Clock(),
- Protocol: &e.mu.ndp,
- NICID: nic.ID(),
- })
+ e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, dadOptions)
e.dad.mu.Unlock()
e.mu.Unlock()
@@ -1879,13 +1928,21 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns the parsed IP header and true if the IP header was successfully
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv6(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
}
if hasTransportHdr {
@@ -1899,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -2013,6 +2070,10 @@ type Options struct {
// DADConfigs holds the default DAD configurations used by IPv6 endpoints.
DADConfigs stack.DADConfigurations
+
+ // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be accepted.
+ AllowExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv6 network protocol.
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 266a53e3b..81f5f23c3 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -343,6 +343,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// TestAddIpv6Address tests adding IPv6 addresses.
func TestAddIpv6Address(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
addr tcpip.Address
@@ -367,18 +369,18 @@ func TestAddIpv6Address(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
}
- if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok {
- t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber)
+ if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err)
} else if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 9a425e50a..85a8f9944 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -15,6 +15,7 @@
package ipv6_test
import (
+ "bytes"
"testing"
"time"
@@ -119,11 +120,26 @@ func TestSendQueuedMLDReports(t *testing.T) {
},
}
+ nonce := [...]byte{
+ 1, 2, 3, 4, 5, 6,
+ }
+
+ const maxNSMessages = 2
+ secureRNGBytes := make([]byte, len(nonce)*maxNSMessages)
+ for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] {
+ if n := copy(b, nonce[:]); n != len(nonce) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce))
+ }
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
clock := faketime.NewManualClock()
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
@@ -154,7 +170,7 @@ func TestSendQueuedMLDReports(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr),
- checker.NDPNSOptions(nil),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}),
))
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d9b728878..536493f87 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1789,18 +1789,14 @@ func (ndp *ndpState) stopSolicitingRouters() {
ndp.rtrSolicitTimer = timer{}
}
-func (ndp *ndpState) init(ep *endpoint) {
+func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
if ndp.defaultRouters != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
- ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{
- Clock: ep.protocol.stack.Clock(),
- Protocol: ndp,
- NICID: ep.nic.ID(),
- })
+ ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
@@ -1811,9 +1807,11 @@ func (ndp *ndpState) init(ep *endpoint) {
}
}
-func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error {
+func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */)
+ return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), header.NDPOptionsSerializer{
+ header.NDPNonceOption(nonce),
+ })
}
func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 6e850fd46..52b9a200c 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "bytes"
"context"
"strings"
"testing"
@@ -1264,8 +1265,21 @@ func TestCheckDuplicateAddress(t *testing.T) {
DupAddrDetectTransmits: 1,
RetransmitTimer: time.Second,
}
+
+ nonces := [...][]byte{
+ {1, 2, 3, 4, 5, 6},
+ {7, 8, 9, 10, 11, 12},
+ }
+
+ var secureRNGBytes []byte
+ for _, n := range nonces {
+ secureRNGBytes = append(secureRNGBytes, n...)
+ }
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- Clock: clock,
+ SecureRNG: &secureRNG,
+ Clock: clock,
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
DADConfigs: dadConfigs,
})},
@@ -1278,10 +1292,36 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- dadPacketsSent := 1
+ dadPacketsSent := 0
+ snmc := header.SolicitedNodeAddr(lladdr0)
+ remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
+ checkDADMsg := func() {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("expected %d-th DAD message", dadPacketsSent)
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ if p.Route.RemoteLinkAddress != remoteLinkAddr {
+ t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
+ ))
+ }
if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
+ checkDADMsg()
// Start DAD for the address we just added.
//
@@ -1297,6 +1337,7 @@ func TestCheckDuplicateAddress(t *testing.T) {
} else if res != stack.DADStarting {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting)
}
+ checkDADMsg()
// Remove the address and make sure our DAD request was not stopped.
if err := s.RemoveAddress(nicID, lladdr0); err != nil {
@@ -1328,33 +1369,6 @@ func TestCheckDuplicateAddress(t *testing.T) {
default:
}
- snmc := header.SolicitedNodeAddr(lladdr0)
- remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
-
- for i := 0; i < dadPacketsSent; i++ {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected %d-th DAD message", i)
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber)
- }
-
- if p.Route.RemoteLinkAddress != remoteLinkAddr {
- t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions(nil),
- ))
- }
-
// Should have no more packets.
if p, ok := e.Read(); ok {
t.Errorf("got unexpected packet = %#v", p)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 47796a6ba..43e6d102c 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"context"
"encoding/binary"
"fmt"
@@ -29,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
@@ -358,6 +360,66 @@ func TestDADDisabled(t *testing.T) {
}
}
+func TestDADResolveLoopback(t *testing.T) {
+ const nicID = 1
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+
+ dadConfigs := stack.DADConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 1,
+ }
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ DADConfigs: dadConfigs,
+ })},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
+ }
+
+ // DAD should not resolve after the normal resolution time since our DAD
+ // message was looped back - we should extend our DAD process.
+ dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
+ clock.Advance(dadResolutionTime)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
+ }
+
+ // Make sure the address does not resolve before the extended resolution time
+ // has passed.
+ const delta = time.Nanosecond
+ // DAD will send extra NS probes if an NS message is looped back.
+ const extraTransmits = 3
+ clock.Advance(dadResolutionTime*extraTransmits - delta)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
+ }
+
+ // DAD should now resolve.
+ clock.Advance(delta)
+ if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -404,6 +466,16 @@ func TestDADResolve(t *testing.T) {
},
}
+ nonces := [][]byte{
+ {1, 2, 3, 4, 5, 6},
+ {7, 8, 9, 10, 11, 12},
+ }
+
+ var secureRNGBytes []byte
+ for _, n := range nonces {
+ secureRNGBytes = append(secureRNGBytes, n...)
+ }
+
for _, test := range tests {
test := test
@@ -419,7 +491,12 @@ func TestDADResolve(t *testing.T) {
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes)
+
s := stack.New(stack.Options{
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: stack.DADConfigurations{
@@ -553,7 +630,7 @@ func TestDADResolve(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr1),
- checker.NDPNSOptions(nil),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}),
))
if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 62f7c880e..ca15c0691 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
-// primaryAddress returns the primary address associated with this NIC.
-//
-// primaryAddress will return the first non-deprecated address if such an
-// address exists. If no non-deprecated address exists, the first deprecated
-// address will be returned.
-func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+// PrimaryAddress implements NetworkInterface.
+func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
ep, ok := n.networkEndpoints[proto]
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{}
}
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{}
}
- return addressableEndpoint.MainAddress()
+ return addressableEndpoint.MainAddress(), nil
}
// removeAddress removes an address from n.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 85f0f471a..ff3a385e1 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -525,6 +525,14 @@ type NetworkInterface interface {
// assigned to it.
Spoofing() bool
+ // PrimaryAddress returns the primary address associated with the interface.
+ //
+ // PrimaryAddress will return the first non-deprecated address if such an
+ // address exists. If no non-deprecated addresses exist, the first deprecated
+ // address will be returned. If no deprecated addresses exist, the zero value
+ // will be returned.
+ PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error)
+
// CheckLocalAddress returns true if the address exists on the interface.
CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 53370c354..931a97ddc 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -23,6 +23,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
+ "io"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -445,6 +446,9 @@ type Stack struct {
// used when a random number is required.
randomGenerator *mathrand.Rand
+ // secureRNG is a cryptographically secure random number generator.
+ secureRNG io.Reader
+
// sendBufferSize holds the min/default/max send buffer sizes for
// endpoints other than TCP.
sendBufferSize tcpip.SendBufferSizeOption
@@ -528,6 +532,9 @@ type Options struct {
// IPTables are the initial iptables rules. If nil, iptables will allow
// all traffic.
IPTables *IPTables
+
+ // SecureRNG is a cryptographically secure random number generator.
+ SecureRNG io.Reader
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -636,6 +643,10 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = rand.Reader
+ }
+
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -652,6 +663,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -1211,20 +1223,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
}
// GetMainNICAddress returns the first non-deprecated primary address and prefix
-// for the given NIC and protocol. If no non-deprecated primary address exists,
-// a deprecated primary address and prefix will be returned. Returns false if
-// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
-// address for the given protocol.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) {
+// for the given NIC and protocol. If no non-deprecated primary addresses exist,
+// a deprecated address will be returned. If no deprecated addresses exist, the
+// zero value will be returned.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.AddressWithPrefix{}, false
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownNICID{}
}
- return nic.primaryAddress(protocol), true
+ return nic.PrimaryAddress(protocol)
}
func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
@@ -2047,6 +2058,12 @@ func (s *Stack) Rand() *mathrand.Rand {
return s.randomGenerator
}
+// SecureRNG returns the stack's cryptographically secure random number
+// generator.
+func (s *Stack) SecureRNG() io.Reader {
+ return s.secureRNG
+}
+
func generateRandUint32() uint32 {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 880219007..7ddf7a083 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -62,10 +62,10 @@ const (
)
func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
- if addr, ok := s.GetMainNICAddress(nicID, proto); !ok {
- return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto)
+ if addr, err := s.GetMainNICAddress(nicID, proto); err != nil {
+ return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err)
} else if addr != want {
- return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want)
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want)
}
return nil
}
@@ -1854,6 +1854,8 @@ func TestNetworkOption(t *testing.T) {
}
func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
+ const nicID = 1
+
for _, addrLen := range []int{4, 16} {
t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
for canBe := 0; canBe < 3; canBe++ {
@@ -1864,8 +1866,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Insert <canBe> primary and <never> never-primary addresses.
// Each one will add a network endpoint to the NIC.
@@ -1888,34 +1890,34 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
PrefixLen: addrLen * 8,
},
}
- if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
- t.Fatal("AddProtocolAddressWithOptions failed:", err)
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address, behavior, err)
}
}
}
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
- gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if len(primaryAddrAdded) == 0 {
// No primary addresses present.
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr)
}
} else {
// At least one primary address was added, verify the returned
// address is in the list of primary addresses we added.
if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded)
}
}
})
@@ -1926,6 +1928,45 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
}
+func TestGetMainNICAddressErrors(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // Sanity check with a successful call.
+ if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil {
+ t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want)
+ }
+
+ const unknownNICID = nicID + 1
+ switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) {
+ case *tcpip.ErrUnknownNICID:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err)
+ }
+
+ // ARP is not an addressable network endpoint.
+ switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) {
+ case *tcpip.ErrNotSupported:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err)
+ }
+
+ const unknownProtocolNumber = 1234
+ switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) {
+ case *tcpip.ErrUnknownProtocol:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err)
+ }
+}
+
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
@@ -2507,11 +2548,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- // Check that we get no address after removal.
- if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
t.Fatal(err)
}
- if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+
+ // Disabling the NIC should remove the auto-generated address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
}
})
@@ -2617,6 +2662,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
// when an address's kind gets "promoted" to permanent from permanentExpired.
func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ const nicID = 1
+
pebs := []stack.PrimaryEndpointBehavior{
stack.NeverPrimaryEndpoint,
stack.CanBePrimaryEndpoint,
@@ -2630,8 +2677,8 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ if err := s.CreateNIC(nicID, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Add a permanent address with initial
@@ -2639,20 +2686,21 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// NeverPrimaryEndpoint, the address should not
// be returned by a call to GetMainNICAddress;
// else, it should.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ const address1 = tcpip.Address("\x01")
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
}
- addr, ok := s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if pi == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
}
- } else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
+ } else if addr.Address != address1 {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
}
{
@@ -2670,13 +2718,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// new peb is respected when an address gets
// "promoted" to permanent from a
// permanentExpired kind.
- r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ const address2 = tcpip.Address("\x02")
+ r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err)
}
defer r.Release()
- if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ if err := s.RemoveAddress(nicID, address1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err)
}
//
@@ -2687,19 +2736,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
+ const address3 = tcpip.Address("\x03")
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, ps, err)
}
var primaryAddrs []tcpip.Address
- for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
}
var expectedList []tcpip.Address
@@ -2728,20 +2778,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// should be returned by a call to
// GetMainNICAddress; else, our original address
// should be returned.
- if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ if err := s.RemoveAddress(nicID, address3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err)
}
- addr, ok = s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ addr, err = s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
}
} else {
- if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
+ if addr.Address != address1 {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
}
}
})
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 10cbbe589..c1c6cbccd 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -33,8 +33,8 @@ import (
)
const (
- testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 58aabe547..3cc8c36f1 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -72,11 +72,13 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 80afc2825..6462e9d42 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -24,11 +24,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
})
}
}
+
+func TestExternalLoopbackTraffic(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ numPackets = 1
+ )
+
+ loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ }
+
+ loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ }
+
+ loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ }
+
+ loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ }
+
+ invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidSourceAddressesReceived
+ }
+
+ invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidDestinationAddressesReceived
+ }
+
+ tests := []struct {
+ name string
+ allowExternalLoopback bool
+ forwarding bool
+ rxICMP func(*channel.Endpoint)
+ invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter
+ shouldAccept bool
+ }{
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: test.allowExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: test.allowExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ }
+
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ }
+
+ if test.forwarding {
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: ipv4Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ tcpip.Route{
+ Destination: header.IPv6Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ stats := s.Stats().IP
+ invalidAddressStat := test.invalidAddressStat(stats)
+ deliveredPacketsStat := stats.PacketsDelivered
+ if got := invalidAddressStat.Value(); got != 0 {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got)
+ }
+ if got := deliveredPacketsStat.Value(); got != 0 {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got)
+ }
+ test.rxICMP(e)
+ var expectedInvalidPackets uint64
+ if !test.shouldAccept {
+ expectedInvalidPackets = numPackets
+ }
+ if got := invalidAddressStat.Value(); got != expectedInvalidPackets {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets)
+ }
+ if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 29266a4fc..77f4a88ec 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -45,82 +45,61 @@ const (
func TestPingMulticastBroadcast(t *testing.T) {
const nicID = 1
- rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ttl,
- SrcAddr: utils.RemoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: utils.RemoteIPv6Addr,
- Dst: dst,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ttl,
- SrcAddr: utils.RemoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- dstAddr tcpip.Address
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectedSrc tcpip.Address
}{
{
- name: "IPv4 unicast",
- dstAddr: utils.Ipv4Addr.Address,
+ name: "IPv4 unicast",
+ protoNum: header.IPv4ProtocolNumber,
+ dstAddr: utils.Ipv4Addr.Address,
+ srcAddr: utils.RemoteIPv4Addr,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 directed broadcast",
- dstAddr: utils.Ipv4SubnetBcast,
+ name: "IPv4 directed broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4SubnetBcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 broadcast",
- dstAddr: header.IPv4Broadcast,
+ name: "IPv4 broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4Broadcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 all-systems multicast",
- dstAddr: header.IPv4AllSystems,
+ name: "IPv4 all-systems multicast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4AllSystems,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv6 unicast",
- dstAddr: utils.Ipv6Addr.Address,
+ name: "IPv6 unicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ expectedSrc: utils.Ipv6Addr.Address,
},
{
- name: "IPv6 all-nodes multicast",
- dstAddr: header.IPv6AllNodesMulticastAddress,
+ name: "IPv6 all-nodes multicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: header.IPv6AllNodesMulticastAddress,
+ expectedSrc: utils.Ipv6Addr.Address,
},
}
@@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- var rxICMP func(*channel.Endpoint, tcpip.Address)
- var expectedSrc tcpip.Address
- var expectedDst tcpip.Address
- var protoNum tcpip.NetworkProtocolNumber
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- rxICMP = rxIPv4ICMP
- expectedSrc = utils.Ipv4Addr.Address
- expectedDst = utils.RemoteIPv4Addr
- protoNum = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- rxICMP = rxIPv6ICMP
- expectedSrc = utils.Ipv6Addr.Address
- expectedDst = utils.RemoteIPv6Addr
- protoNum = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- rxICMP(e, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
}
- if pkt.Route.LocalAddress != expectedSrc {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc)
+ if pkt.Route.LocalAddress != test.expectedSrc {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
}
- if pkt.Route.RemoteAddress != expectedDst {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if pkt.Route.RemoteAddress != test.srcAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if src != expectedSrc {
- t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
+ src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if src != test.expectedSrc {
+ t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
}
- if dst != expectedDst {
- t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if dst != test.srcAddr {
+ t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 4455f6dd7..ed499179f 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -16,6 +16,7 @@ package route_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- HandleLocal: true,
- })
- e := test.linkEndpoint()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
+ for _, allowExternalLoopback := range []bool{true, false} {
+ t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: allowExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: allowExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
- }
- }
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
- {
- err := ep.Connect(connAddr)
- if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff)
- }
- }
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
- if test.expectedConnectErr != nil {
- return
- }
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
- payload := test.icmpBuf(t)
- var r bytes.Reader
- r.Reset(payload)
- var wOpts tcpip.WriteOptions
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
- } else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
- }
+ if test.expectedConnectErr != nil {
+ return
+ }
- // Wait for the endpoint to become readable.
- <-ch
+ var r bytes.Reader
+ payload := test.icmpBuf(t)
+ r.Reset(payload)
+ var wOpts tcpip.WriteOptions
+ if n, err := ep.Write(&r, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
+ // Wait for the endpoint to become readable.
+ <-ch
- test.checkLinkEndpoint(t, e)
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if rr.RemoteAddr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
})
}
}
diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD
index 433004148..a9699a367 100644
--- a/pkg/tcpip/tests/utils/BUILD
+++ b/pkg/tcpip/tests/utils/BUILD
@@ -8,12 +8,15 @@ go_library(
visibility = ["//pkg/tcpip/tests:__subpackages__"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
],
)
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index f414a2234..d1c9f3a94 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -20,13 +20,16 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
// Common NIC IDs used by tests.
@@ -45,6 +48,10 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
+const (
+ ttl = 255
+)
+
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
},
})
}
+
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
+// the provided endpoint.
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(header.ICMPv4UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(header.ICMPv6UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: src,
+ Dst: dst,
+ }))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3b574837c..0a2f3291c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash"
"io"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
-func (e *endpoint) deliverAccepted(n *endpoint) {
+func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
@@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
select {
case e.acceptedChan <- n:
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
e.acceptMu.Unlock()
e.waiterQueue.Notify(waiter.EventIn)
return
@@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.synRcvdCount--
+ atomic.AddInt32(&e.synRcvdCount, -1)
return err
}
@@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
+ atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep)
+ e.deliverAccepted(h.ep, false /* withSynCookie */)
}() // S/R-SAFE: synRcvdCount is the barrier.
return nil
@@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) incSynRcvdCount() bool {
e.acceptMu.Lock()
- canInc := e.synRcvdCount < cap(e.acceptedChan)
+ canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
e.acceptMu.Unlock()
if canInc {
- e.synRcvdCount++
+ atomic.AddInt32(&e.synRcvdCount, 1)
}
return canInc
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
e.acceptMu.Unlock()
return full
}
@@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n)
+ go e.deliverAccepted(n, true /* withSynCookie */)
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 129f36d11..43d344350 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -532,8 +532,8 @@ type endpoint struct {
segmentQueue segmentQueue `state:"wait"`
// synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state.
- synRcvdCount int
+ // in SYN-RCVD state; this is only accessed atomically.
+ synRcvdCount int32
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index a5c82b8fa..bc6793fc6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -260,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fd499a47b..6c86ae1ae 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -165,7 +165,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
@@ -178,7 +178,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
@@ -239,7 +239,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
@@ -269,7 +269,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -318,7 +318,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
// Send a SYN request for a closed port. This should elicit an RST
// but NOT an ICMPv4 DstUnreachable packet.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -362,7 +362,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -459,7 +459,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
@@ -483,7 +483,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
@@ -506,14 +506,14 @@ func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -537,18 +537,19 @@ func TestConnectResetAfterClose(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
// Close the endpoint, make sure we get a FIN segment, then acknowledge
// to complete closure of sender, but don't send our own FIN.
ep.Close()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -556,7 +557,7 @@ func TestConnectResetAfterClose(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -570,7 +571,7 @@ func TestConnectResetAfterClose(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -612,7 +613,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -625,12 +626,12 @@ func TestCurrentConnectedIncrement(t *testing.T) {
}
ep.Close()
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -638,7 +639,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -655,7 +656,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -666,7 +667,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -690,7 +691,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -699,11 +700,12 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -713,7 +715,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -734,7 +736,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -753,7 +755,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -764,7 +766,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 792,
+ SeqNum: iss.Add(2),
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -804,7 +806,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -813,11 +815,12 @@ func TestSimpleReceive(t *testing.T) {
ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -840,7 +843,7 @@ func TestSimpleReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1366,12 +1369,13 @@ func TestTOSV4(t *testing.T) {
// Check that data is received.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
@@ -1414,12 +1418,13 @@ func TestTrafficClassV6(t *testing.T) {
// Check that data is received.
b := c.GetV6Packet()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv6(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
@@ -1472,7 +1477,7 @@ func TestConnectBindToDevice(t *testing.T) {
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
@@ -1537,7 +1542,7 @@ func TestSynSent(t *testing.T) {
if test.reset {
// Send a packet with a proper ACK and a RST flag to cause the socket
// to error and close out.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
@@ -1582,7 +1587,7 @@ func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1593,11 +1598,12 @@ func TestOutOfOrderReceive(t *testing.T) {
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: iss.Add(3),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1607,7 +1613,7 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1621,7 +1627,7 @@ func TestOutOfOrderReceive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1639,7 +1645,7 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1650,19 +1656,20 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
rcvBufSz := math.MaxUint16
- c.CreateConnected(789, 30000, rcvBufSz)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
ept := endpointTester{c.EP}
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < 100; i++ {
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 796,
+ SeqNum: iss.Add(6),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1671,19 +1678,19 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- // Send packet with seqnum 793. It must be discarded because the
+ // Send packet with seqnum as initial + 3. It must be discarded because the
// out-of-order buffer was filled by the previous packets.
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: iss.Add(3),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1692,27 +1699,27 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
- // Now send the expected packet, seqnum 790.
+ // Now send the expected packet with initial sequence number.
c.SendPacket(data[:3], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- // Check that only packet 790 is acknowledged.
+ // Check that only packet with initial sequence number is acknowledged.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(793),
+ checker.TCPAckNum(uint32(iss)+3),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1722,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1732,11 +1739,12 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1753,7 +1761,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1780,7 +1788,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
+ SeqNum: iss.Add(seqnum.Size(len(data))),
AckNum: c.IRS.Add(seqnum.Size(2)),
RcvWnd: 30000,
})
@@ -1790,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1800,11 +1808,12 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1821,7 +1830,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1866,7 +1875,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
+ SeqNum: iss.Add(seqnum.Size(len(data))),
AckNum: c.IRS.Add(seqnum.Size(2)),
RcvWnd: 30000,
})
@@ -1876,7 +1885,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ept := endpointTester{c.EP}
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
@@ -1897,7 +1906,7 @@ func TestFullWindowReceive(t *testing.T) {
defer c.Cleanup()
const rcvBufSz = 10
- c.CreateConnected(789, 30000, rcvBufSz)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1913,11 +1922,12 @@ func TestFullWindowReceive(t *testing.T) {
for i := range data {
data[i] = byte(i % 255)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1934,7 +1944,7 @@ func TestFullWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(0),
),
@@ -1956,7 +1966,7 @@ func TestFullWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(10),
),
@@ -1996,8 +2006,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
}
payload := generateRandomPayload(t, payloadSize)
payloadLen := seqnum.Size(len(payload))
- iss := seqnum.Value(789)
- seqNum := iss.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Send payload to the endpoint and return the advertised receive window
// from the endpoint.
@@ -2005,12 +2014,12 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
Flags: header.TCPFlagAck,
RcvWnd: 30000,
})
- seqNum = seqNum.Add(payloadLen)
+ iss = iss.Add(payloadLen)
pkt := c.GetPacket()
return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
@@ -2054,9 +2063,8 @@ func TestNoWindowShrinking(t *testing.T) {
// the right edge of the window does not shrink.
// NOTE: Netstack doubles the value specified here.
rcvBufSize := 65536
- iss := seqnum.Value(789)
// Enable window scaling with a scale of zero from our end.
- c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2069,13 +2077,13 @@ func TestNoWindowShrinking(t *testing.T) {
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
- seqNum := iss.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
payload := []byte{1}
c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2092,20 +2100,20 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data: %v, want: %v", got, want)
}
- seqNum = seqNum.Add(1)
// Verify that the ACK does not shrink the window.
pkt := c.GetPacket()
+ iss = iss.Add(1)
checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
// Stash the initial window.
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
@@ -2117,11 +2125,11 @@ func TestNoWindowShrinking(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+ iss = iss.Add(seqnum.Size(rcvBufSize / 2))
// Verify that the ACK does not shrink the window.
pkt = c.GetPacket()
@@ -2129,12 +2137,12 @@ func TestNoWindowShrinking(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd))
if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
@@ -2145,17 +2153,17 @@ func TestNoWindowShrinking(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+ iss = iss.Add(seqnum.Size(rcvBufSize / 2))
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(0),
),
@@ -2173,7 +2181,7 @@ func TestNoWindowShrinking(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
@@ -2184,7 +2192,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2195,12 +2203,13 @@ func TestSimpleSend(t *testing.T) {
// Check that data is received.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2214,7 +2223,7 @@ func TestSimpleSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
RcvWnd: 30000,
})
@@ -2224,7 +2233,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2235,12 +2244,13 @@ func TestZeroWindowSend(t *testing.T) {
// Check if we got a zero-window probe.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2250,7 +2260,7 @@ func TestZeroWindowSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2262,7 +2272,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2276,7 +2286,7 @@ func TestZeroWindowSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
RcvWnd: 30000,
})
@@ -2289,7 +2299,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2303,12 +2313,13 @@ func TestScaledWindowConnect(t *testing.T) {
// Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2322,7 +2333,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- c.CreateConnected(789, 30000, 65535*3)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2334,12 +2345,13 @@ func TestNonScaledWindowConnect(t *testing.T) {
// Check that data is received, and that advertised window is 0xffff,
// that is, that it's not scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2407,12 +2419,13 @@ func TestScaledWindowAccept(t *testing.T) {
// Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2480,12 +2493,13 @@ func TestNonScaledWindowAccept(t *testing.T) {
// Check that data is received, and that advertised window is 0xffff,
// that is, that it's not scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2502,7 +2516,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the buffer size such that a window scale of 5 will be used.
const bufSz = 65535 * 10
const ws = uint32(5)
- c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2510,13 +2524,14 @@ func TestZeroScaledWindowReceive(t *testing.T) {
remain := 0
sent := 0
data := make([]byte, 50000)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Keep writing till the window drops below len(data).
for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2527,7 +2542,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2545,7 +2560,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2556,7 +2571,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2594,7 +2609,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -2632,7 +2647,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send tcp.InitialCwnd number of segments to fill up
// InitialWindow but don't ACK. That should prevent
@@ -2657,6 +2672,7 @@ func TestSegmentMerging(t *testing.T) {
}
// Check that we get tcp.InitialCwnd packets.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < tcp.InitialCwnd; i++ {
b := c.GetPacket()
checker.IPv4(t, b,
@@ -2664,7 +2680,7 @@ func TestSegmentMerging(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2675,7 +2691,7 @@ func TestSegmentMerging(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
RcvWnd: 30000,
})
@@ -2687,7 +2703,7 @@ func TestSegmentMerging(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+11),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2701,7 +2717,7 @@ func TestSegmentMerging(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
@@ -2713,7 +2729,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
c.EP.SocketOptions().SetDelayOption(true)
@@ -2728,6 +2744,7 @@ func TestDelay(t *testing.T) {
}
seq := c.IRS.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for _, want := range [][]byte{allData[:1], allData[1:]} {
// Check that data is received.
b := c.GetPacket()
@@ -2736,7 +2753,7 @@ func TestDelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2751,7 +2768,7 @@ func TestDelay(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2762,7 +2779,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
c.EP.SocketOptions().SetDelayOption(true)
@@ -2776,7 +2793,7 @@ func TestUndelay(t *testing.T) {
}
seq := c.IRS.Add(1)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Check that data is received.
first := c.GetPacket()
checker.IPv4(t, first,
@@ -2784,7 +2801,7 @@ func TestUndelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2807,7 +2824,7 @@ func TestUndelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2823,7 +2840,7 @@ func TestUndelay(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2845,7 +2862,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -2861,7 +2878,7 @@ func TestMSSNotDelayed(t *testing.T) {
}
seq := c.IRS.Add(1)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i, data := range allData {
// Check that data is received.
packet := c.GetPacket()
@@ -2870,7 +2887,7 @@ func TestMSSNotDelayed(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2887,7 +2904,7 @@ func TestMSSNotDelayed(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2912,6 +2929,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
// Check that data is received in chunks.
bytesReceived := 0
numPackets := 0
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for bytesReceived != dataLen {
b := c.GetPacket()
numPackets++
@@ -2921,7 +2939,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2945,7 +2963,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
RcvWnd: 30000,
TCPOpts: options,
@@ -2961,7 +2979,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -3001,7 +3019,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -3210,7 +3228,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
)
// Send SYN-ACK.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
DstPort: tcpHdr.SourcePort(),
@@ -3272,14 +3290,15 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
RcvWnd: 30000,
})
@@ -3328,14 +3347,15 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send RST segment.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
RcvWnd: 30000,
})
@@ -3363,7 +3383,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
@@ -3426,7 +3446,7 @@ func TestMaxRTO(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
var r bytes.Reader
r.Reset(make([]byte, 1))
@@ -3469,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
// have DF=0. This needs to be done because the IPv4 ID uniqueness
@@ -3518,19 +3538,20 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3540,7 +3561,7 @@ func TestFinImmediately(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -3551,7 +3572,7 @@ func TestFinImmediately(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3561,19 +3582,20 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3584,7 +3606,7 @@ func TestFinRetransmit(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3594,7 +3616,7 @@ func TestFinRetransmit(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -3605,7 +3627,7 @@ func TestFinRetransmit(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3615,7 +3637,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := make([]byte, 10)
@@ -3626,12 +3648,13 @@ func TestFinWithNoPendingData(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3641,7 +3664,7 @@ func TestFinWithNoPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3656,7 +3679,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3667,7 +3690,7 @@ func TestFinWithNoPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3678,7 +3701,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3688,7 +3711,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -3702,13 +3725,14 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := tcp.InitialCwnd; i > 0; i-- {
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3727,7 +3751,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3737,7 +3761,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3747,7 +3771,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3758,7 +3782,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3768,7 +3792,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3778,7 +3802,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := make([]byte, 10)
@@ -3789,12 +3813,13 @@ func TestFinWithPendingData(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3804,7 +3829,7 @@ func TestFinWithPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3820,7 +3845,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3836,7 +3861,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3847,7 +3872,7 @@ func TestFinWithPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3857,7 +3882,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3867,7 +3892,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -3879,12 +3904,13 @@ func TestFinWithPartialAck(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3894,7 +3920,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3905,7 +3931,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3921,7 +3947,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3937,7 +3963,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3948,7 +3974,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: seqnum.Value(next - 1),
RcvWnd: 30000,
})
@@ -3961,7 +3987,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -4002,17 +4028,18 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
// Open up the window with a scaled value.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 1,
})
@@ -4031,7 +4058,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4041,7 +4068,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
})
}
@@ -4054,15 +4081,16 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4083,14 +4111,15 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
vv := c.BuildSegment(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4110,14 +4139,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4144,23 +4174,24 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
data := []byte{1, 2, 3}
for i := 0; i < 200; i++ {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + i*len(data)),
+ SeqNum: iss.Add(seqnum.Size(i * len(data))),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
}
// Receive ACKs for all segments.
- last := seqnum.Value(790 + 200*len(data))
+ last := iss.Add(seqnum.Size(200 * len(data)))
for {
b := c.GetPacket()
checker.IPv4(t, b,
@@ -4198,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -4212,12 +4243,13 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -4232,7 +4264,7 @@ func TestReadAfterClosedState(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -4242,7 +4274,7 @@ func TestReadAfterClosedState(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(791+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4815,7 +4847,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -4833,6 +4865,7 @@ func TestPathMTUDiscovery(t *testing.T) {
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
var ret []byte
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i, size := range sizes {
p := c.GetPacket()
if i == which {
@@ -4843,7 +4876,7 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(seqNum),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4893,14 +4926,15 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -5027,7 +5061,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if connected {
- c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
}
if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
@@ -5067,7 +5101,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
@@ -5087,13 +5121,14 @@ func TestKeepalive(t *testing.T) {
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < 10; i++ {
b := c.GetPacket()
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -5103,7 +5138,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS,
RcvWnd: 30000,
})
@@ -5128,7 +5163,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -5140,7 +5175,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -5153,7 +5188,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -5166,7 +5201,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(next-1)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -5184,7 +5219,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -5215,7 +5250,7 @@ func TestKeepalive(t *testing.T) {
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
t.Helper()
// Send a SYN request.
- irs = seqnum.Value(789)
+ irs = seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: srcPort,
DstPort: context.StackPort,
@@ -5260,7 +5295,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
t.Helper()
// Send a SYN request.
- irs = seqnum.Value(789)
+ irs = seqnum.Value(context.TestInitialSequenceNumber)
c.SendV6Packet(nil, &context.Headers{
SrcPort: srcPort,
DstPort: context.StackPort,
@@ -5340,7 +5375,7 @@ func TestListenBacklogFull(t *testing.T) {
SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
RcvWnd: 30000,
})
c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
@@ -5491,7 +5526,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
t.Fatalf("Listen failed: %s", err)
}
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5591,7 +5626,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
t.Fatalf("Listen failed: %s", err)
}
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendV6PacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5645,7 +5680,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Send two SYN's the first one should get a SYN-ACK, the
// second one should not get any response and is dropped as
// the synRcvd count will be equal to backlog.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5758,7 +5793,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
time.Sleep(50 * time.Millisecond)
// Send a SYN request.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
// pick a different src port for new SYN.
SrcPort: context.TestPort + 1,
@@ -5824,7 +5859,7 @@ func TestSYNRetransmit(t *testing.T) {
// Send the same SYN packet multiple times. We should still get a valid SYN-ACK
// reply.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
for i := 0; i < 5; i++ {
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5867,7 +5902,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
}
// Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6051,7 +6086,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
SrcPort: srcPort + 2,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
RcvWnd: 30000,
})
@@ -6501,7 +6536,7 @@ func TestTCPLingerTimeout(t *testing.T) {
c := context.New(t, 1500 /* mtu */)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
testCases := []struct {
name string
@@ -6552,7 +6587,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6671,7 +6706,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6778,7 +6813,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6852,7 +6887,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Send a SYN request w/ sequence number lower than
// the highest sequence number sent. We just reuse
// the same number.
- iss = seqnum.Value(789)
+ iss = seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6872,7 +6907,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
- iss = seqnum.Value(792)
+ iss = iss.Add(3)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6942,7 +6977,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -7091,7 +7126,7 @@ func TestTCPCloseWithData(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -7242,7 +7277,7 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
@@ -7268,12 +7303,13 @@ func TestTCPUserTimeout(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7299,7 +7335,7 @@ func TestTCPUserTimeout(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -7328,7 +7364,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
@@ -7362,11 +7398,12 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7383,7 +7420,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(c.IRS + 1),
RcvWnd: 30000,
})
@@ -7413,20 +7450,20 @@ func TestIncreaseWindowOnRead(t *testing.T) {
defer c.Cleanup()
const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -7438,7 +7475,7 @@ func TestIncreaseWindowOnRead(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7469,7 +7506,7 @@ func TestIncreaseWindowOnRead(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7483,20 +7520,20 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
defer c.Cleanup()
const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -7507,7 +7544,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7523,7 +7560,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7664,16 +7701,16 @@ func TestResetDuringClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- iss := seqnum.Value(789)
- c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */)
// Send some data to make sure there is some unread
// data to trigger a reset on c.Close.
irs := c.IRS
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
+ SeqNum: iss,
AckNum: irs.Add(1),
RcvWnd: 30000,
})
@@ -7683,7 +7720,7 @@ func TestResetDuringClose(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPSeqNum(uint32(irs.Add(1))),
- checker.TCPAckNum(uint32(iss.Add(5)))))
+ checker.TCPAckNum(uint32(iss)+4)))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7698,7 +7735,7 @@ func TestResetDuringClose(t *testing.T) {
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: iss.Add(5),
+ SeqNum: iss.Add(4),
AckNum: c.IRS.Add(5),
RcvWnd: 30000,
Flags: header.TCPFlagRst,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 5d81dbb94..c8126b51b 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -45,8 +45,8 @@ import (
// represents the remote endpoint.
const (
v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
testV4MappedAddr = v4MappedAddrPrefix + testAddr
multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
index 047091e75..dbe17fa5e 100644
--- a/pkg/test/dockerutil/network.go
+++ b/pkg/test/dockerutil/network.go
@@ -102,11 +102,8 @@ func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
}
-// Cleanup cleans up the docker network and all the containers attached to it.
+// Cleanup cleans up the docker network.
func (n *Network) Cleanup(ctx context.Context) error {
- for _, c := range n.containers {
- c.CleanUp(ctx)
- }
n.containers = nil
return n.client.NetworkRemove(ctx, n.id)
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 1cd5fba5c..1ae76d7d7 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -400,7 +400,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Set up the restore environment.
ctx := k.SupervisorContext()
- mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints)
+ mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled)
if kernel.VFS2Enabled {
ctx, err = mntr.configureRestore(ctx, cm.l.root.conf)
if err != nil {
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 77f632bb9..32adde643 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -103,14 +103,14 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name
// compileMounts returns the supported mounts from the mount spec, adding any
// mandatory mounts that are required by the OCI specification.
-func compileMounts(spec *specs.Spec) []specs.Mount {
+func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
// Keep track of whether proc and sys were mounted.
var procMounted, sysMounted, devMounted, devptsMounted bool
var mounts []specs.Mount
// Mount all submounts from the spec.
for _, m := range spec.Mounts {
- if !specutils.IsSupportedDevMount(m) {
+ if !vfs2Enabled && !specutils.IsVFS1SupportedDevMount(m) {
log.Warningf("ignoring dev mount at %q", m.Destination)
continue
}
@@ -572,10 +572,10 @@ type containerMounter struct {
hints *podMountHints
}
-func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints) *containerMounter {
+func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter {
return &containerMounter{
root: spec.Root,
- mounts: compileMounts(spec),
+ mounts: compileMounts(spec, vfs2Enabled),
fds: fdDispenser{fds: goferFDs},
k: k,
hints: hints,
@@ -792,7 +792,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M
case bind:
fd := c.fds.remove()
fsName = gofervfs2.Name
- opts = p9MountData(fd, c.getMountAccessType(m), conf.VFS2)
+ opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2)
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
@@ -802,12 +802,11 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M
return fsName, opts, useOverlay, nil
}
-func (c *containerMounter) getMountAccessType(mount specs.Mount) config.FileAccessType {
+func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.Mount) config.FileAccessType {
if hint := c.hints.findMount(mount); hint != nil {
return hint.fileAccessType()
}
- // Non-root bind mounts are always shared if no hints were provided.
- return config.FileAccessShared
+ return conf.FileAccessMounts
}
// mountSubmount mounts volumes inside the container's root. Because mounts may
diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go
index e986231e5..b4f12d034 100644
--- a/runsc/boot/fs_test.go
+++ b/runsc/boot/fs_test.go
@@ -243,7 +243,8 @@ func TestGetMountAccessType(t *testing.T) {
t.Fatalf("newPodMountHints failed: %v", err)
}
mounter := containerMounter{hints: podHints}
- if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want {
+ conf := &config.Config{FileAccessMounts: config.FileAccessShared}
+ if got := mounter.getMountAccessType(conf, specs.Mount{Source: source}); got != tst.want {
t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got)
}
})
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 5afce232d..774621970 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -752,7 +752,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
// Setup the child container file system.
l.startGoferMonitor(cid, info.goferFDs)
- mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints)
+ mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled)
if root {
if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil {
return nil, nil, nil, err
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 3121ca6eb..8b39bc59a 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -439,7 +439,7 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{})
+ mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */)
mns, err := mntr.createMountNamespace(ctx, conf)
if err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
@@ -479,7 +479,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
defer l.Destroy()
defer loaderCleanup()
- mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints)
+ mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */)
if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil {
t.Fatalf("failed process hints: %v", err)
}
@@ -702,7 +702,7 @@ func TestRestoreEnvironment(t *testing.T) {
for _, ioFD := range tc.ioFDs {
ioFDs = append(ioFDs, fd.New(ioFD))
}
- mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{})
+ mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */)
actualRenv, err := mntr.createRestoreEnvironment(conf)
if !tc.errorExpected && err != nil {
t.Fatalf("could not create restore environment for test:%s", tc.name)
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 3fd28e516..9b3dacf46 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -494,7 +494,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
// but unlikely to be correct in this context.
return "", nil, false, fmt.Errorf("9P mount requires a connection FD")
}
- data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+ data = p9MountData(m.fd, c.getMountAccessType(conf, m.Mount), true /* vfs2 */)
iopts = gofer.InternalFilesystemOptions{
UniqueID: m.Destination,
}
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 22c1dfeb8..455c57692 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -42,10 +42,11 @@ var errNoDefaultInterface = errors.New("no default interface found")
// Do implements subcommands.Command for the "do" command. It sets up a simple
// sandbox and executes the command inside it. See Usage() for more details.
type Do struct {
- root string
- cwd string
- ip string
- quiet bool
+ root string
+ cwd string
+ ip string
+ quiet bool
+ overlay bool
}
// Name implements subcommands.Command.Name.
@@ -76,6 +77,7 @@ func (c *Do) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.cwd, "cwd", ".", "path to the current directory, defaults to the current directory")
f.StringVar(&c.ip, "ip", "192.168.10.2", "IPv4 address for the sandbox")
f.BoolVar(&c.quiet, "quiet", false, "suppress runsc messages to stdout. Application output is still sent to stdout and stderr")
+ f.BoolVar(&c.overlay, "force-overlay", true, "use an overlay. WARNING: disabling gives the command write access to the host")
}
// Execute implements subcommands.Command.Execute.
@@ -100,9 +102,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
return Errorf("Error to retrieve hostname: %v", err)
}
- // Map the entire host file system, but make it readonly with a writable
- // overlay on top (ignore --overlay option).
- conf.Overlay = true
+ // Map the entire host file system, optionally using an overlay.
+ conf.Overlay = c.overlay
absRoot, err := resolvePath(c.root)
if err != nil {
return Errorf("Error resolving root: %v", err)
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 639b2219c..4cb0164dd 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -165,8 +165,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Start with root mount, then add any other additional mount as needed.
ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
- ROMount: spec.Root.Readonly || conf.Overlay,
- EnableXattr: conf.Verity,
+ ROMount: spec.Root.Readonly || conf.Overlay,
+ EnableVerityXattr: conf.Verity,
})
if err != nil {
Fatalf("creating attach point: %v", err)
@@ -178,9 +178,9 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
for _, m := range spec.Mounts {
if specutils.Is9PMount(m) {
cfg := fsgofer.Config{
- ROMount: isReadonlyMount(m.Options) || conf.Overlay,
- HostUDS: conf.FSGoferHostUDS,
- EnableXattr: conf.Verity,
+ ROMount: isReadonlyMount(m.Options) || conf.Overlay,
+ HostUDS: conf.FSGoferHostUDS,
+ EnableVerityXattr: conf.Verity,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -203,6 +203,10 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
filter.InstallUDSFilters()
}
+ if conf.Verity {
+ filter.InstallXattrFilters()
+ }
+
if err := filter.Install(); err != nil {
Fatalf("installing seccomp filters: %v", err)
}
@@ -346,7 +350,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// creates directories as needed.
func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
for _, m := range mounts {
- if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
+ if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) {
continue
}
@@ -386,7 +390,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) {
cleanMounts := make([]specs.Mount, 0, len(mounts))
for _, m := range mounts {
- if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
+ if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) {
cleanMounts = append(cleanMounts, m)
continue
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index 34ef48825..1e5858837 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -58,9 +58,12 @@ type Config struct {
// DebugLogFormat is the log format for debug.
DebugLogFormat string `flag:"debug-log-format"`
- // FileAccess indicates how the filesystem is accessed.
+ // FileAccess indicates how the root filesystem is accessed.
FileAccess FileAccessType `flag:"file-access"`
+ // FileAccessMounts indicates how non-root volumes are accessed.
+ FileAccessMounts FileAccessType `flag:"file-access-mounts"`
+
// Overlay is whether to wrap the root filesystem in an overlay.
Overlay bool `flag:"overlay"`
@@ -197,13 +200,19 @@ func (c *Config) validate() error {
type FileAccessType int
const (
- // FileAccessExclusive is the same as FileAccessShared, but enables
- // extra caching for improved performance. It should only be used if
- // the sandbox has exclusive access to the filesystem.
+ // FileAccessExclusive gives the sandbox exclusive access over files and
+ // directories in the filesystem. No external modifications are permitted and
+ // can lead to undefined behavior.
+ //
+ // Exclusive filesystem access enables more aggressive caching and offers
+ // significantly better performance. This is the default mode for the root
+ // volume.
FileAccessExclusive FileAccessType = iota
- // FileAccessShared sends IO requests to a Gofer process that validates the
- // requests and forwards them to the host.
+ // FileAccessShared is used for volumes that can have external changes. It
+ // requires revalidation on every filesystem access to detect external
+ // changes, and reduces the amount of caching that can be done. This is the
+ // default mode for non-root volumes.
FileAccessShared
)
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index adbee506c..1d996c841 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -67,7 +67,8 @@ func RegisterFlags() {
flag.Bool("oci-seccomp", false, "Enables loading OCI seccomp filters inside the sandbox.")
// Flags that control sandbox runtime behavior: FS related.
- flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
+ flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem validation to use for the root mount: exclusive (default), shared.")
+ flag.Var(fileAccessTypePtr(FileAccessShared), "file-access-mounts", "specifies which filesystem validation to use for volumes other than the root mount: shared (default), exclusive.")
flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
flag.Bool("verity", false, "specifies whether a verity file system will be mounted.")
flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem")
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index fd72414ce..246b7ed3c 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -247,3 +247,8 @@ var udsSyscalls = seccomp.SyscallRules{
},
},
}
+
+var xattrSyscalls = seccomp.SyscallRules{
+ unix.SYS_FGETXATTR: {},
+ unix.SYS_FSETXATTR: {},
+}
diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go
index 289886720..6c67ee288 100644
--- a/runsc/fsgofer/filter/filter.go
+++ b/runsc/fsgofer/filter/filter.go
@@ -36,3 +36,9 @@ func InstallUDSFilters() {
// Add additional filters required for connecting to the host's sockets.
allowedSyscalls.Merge(udsSyscalls)
}
+
+// InstallXattrFilters extends the allowed syscalls to include xattr calls that
+// are necessary for Verity enabled file systems.
+func InstallXattrFilters() {
+ allowedSyscalls.Merge(xattrSyscalls)
+}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 1e80a634d..e04ddda47 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -48,6 +48,14 @@ const (
allowedOpenFlags = unix.O_TRUNC
)
+// verityXattrs are the extended attributes used by verity file system.
+var verityXattrs = map[string]struct{}{
+ "user.merkle.offset": struct{}{},
+ "user.merkle.size": struct{}{},
+ "user.merkle.childrenOffset": struct{}{},
+ "user.merkle.childrenSize": struct{}{},
+}
+
// join is equivalent to path.Join() but skips path.Clean() which is expensive.
func join(parent, child string) string {
if child == "." || child == ".." {
@@ -67,8 +75,9 @@ type Config struct {
// HostUDS signals whether the gofer can mount a host's UDS.
HostUDS bool
- // enableXattr allows Get/SetXattr for the mounted file systems.
- EnableXattr bool
+ // EnableVerityXattr allows access to extended attributes used by the
+ // verity file system.
+ EnableVerityXattr bool
}
type attachPoint struct {
@@ -799,7 +808,10 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
}
func (l *localFile) GetXattr(name string, size uint64) (string, error) {
- if !l.attachPoint.conf.EnableXattr {
+ if !l.attachPoint.conf.EnableVerityXattr {
+ return "", unix.EOPNOTSUPP
+ }
+ if _, ok := verityXattrs[name]; !ok {
return "", unix.EOPNOTSUPP
}
buffer := make([]byte, size)
@@ -810,7 +822,10 @@ func (l *localFile) GetXattr(name string, size uint64) (string, error) {
}
func (l *localFile) SetXattr(name string, value string, flags uint32) error {
- if !l.attachPoint.conf.EnableXattr {
+ if !l.attachPoint.conf.EnableVerityXattr {
+ return unix.EOPNOTSUPP
+ }
+ if _, ok := verityXattrs[name]; !ok {
return unix.EOPNOTSUPP
}
return unix.Fsetxattr(l.file.FD(), name, []byte(value), int(flags))
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index a5f09f88f..d7e141476 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -579,20 +579,24 @@ func SetGetXattr(l *localFile, name string, value string) error {
return nil
}
+func TestSetGetDisabledXattr(t *testing.T) {
+ runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) {
+ name := "user.merkle.offset"
+ value := "tmp"
+ err := SetGetXattr(s.file, name, value)
+ if err == nil {
+ t.Fatalf("%v: SetGetXattr should have failed", s)
+ }
+ })
+}
+
func TestSetGetXattr(t *testing.T) {
- xattrConfs := []Config{{ROMount: false, EnableXattr: false}, {ROMount: false, EnableXattr: true}}
- runCustom(t, []uint32{unix.S_IFREG}, xattrConfs, func(t *testing.T, s state) {
- name := "user.test"
+ runCustom(t, []uint32{unix.S_IFREG}, []Config{{ROMount: false, EnableVerityXattr: true}}, func(t *testing.T, s state) {
+ name := "user.merkle.offset"
value := "tmp"
err := SetGetXattr(s.file, name, value)
- if s.conf.EnableXattr {
- if err != nil {
- t.Fatalf("%v: SetGetXattr failed, err: %v", s, err)
- }
- } else {
- if err == nil {
- t.Fatalf("%v: SetGetXattr should have failed", s)
- }
+ if err != nil {
+ t.Fatalf("%v: SetGetXattr failed, err: %v", s, err)
}
})
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 5ba38bfe4..45856fd58 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -334,14 +334,13 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth.
// Is9PMount returns true if the given mount can be mounted as an external gofer.
func Is9PMount(m specs.Mount) bool {
- return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m)
+ return m.Type == "bind" && m.Source != "" && IsVFS1SupportedDevMount(m)
}
-// IsSupportedDevMount returns true if the mount is a supported /dev mount.
-// Only mount that does not conflict with runsc default /dev mount is
-// supported.
-func IsSupportedDevMount(m specs.Mount) bool {
- // These are devices exist inside sentry. See pkg/sentry/fs/dev/dev.go
+// IsVFS1SupportedDevMount returns true if m.Destination does not specify a
+// path that is hardcoded by VFS1's implementation of /dev.
+func IsVFS1SupportedDevMount(m specs.Mount) bool {
+ // See pkg/sentry/fs/dev/dev.go.
var existingDevices = []string{
"/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr",
"/dev/null", "/dev/zero", "/dev/full", "/dev/random",
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index c94caab60..dc82e63b2 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -8,6 +8,7 @@ benchmark_test(
srcs = ["bazel_test.go"],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/cleanup",
"//pkg/test/dockerutil",
"//test/benchmarks/harness",
"//test/benchmarks/tools",
@@ -21,6 +22,7 @@ benchmark_test(
srcs = ["fio_test.go"],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/cleanup",
"//pkg/test/dockerutil",
"//test/benchmarks/harness",
"//test/benchmarks/tools",
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
index 7ced963f6..797b1952d 100644
--- a/test/benchmarks/fs/bazel_test.go
+++ b/test/benchmarks/fs/bazel_test.go
@@ -20,6 +20,7 @@ import (
"strings"
"testing"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
@@ -28,8 +29,8 @@ import (
// Dimensions here are clean/dirty cache (do or don't drop caches)
// and if the mount on which we are compiling is a tmpfs/bind mount.
type benchmark struct {
- clearCache bool // clearCache drops caches before running.
- fstype string // type of filesystem to use.
+ clearCache bool // clearCache drops caches before running.
+ fstype harness.FileSystemType // type of filesystem to use.
}
// Note: CleanCache versions of this test require running with root permissions.
@@ -48,12 +49,12 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
// Get a machine from the Harness on which to run.
machine, err := harness.GetMachine()
if err != nil {
- b.Fatalf("failed to get machine: %v", err)
+ b.Fatalf("Failed to get machine: %v", err)
}
defer machine.CleanUp()
benchmarks := make([]benchmark, 0, 6)
- for _, filesys := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} {
+ for _, filesys := range []harness.FileSystemType{harness.BindFS, harness.TmpFS, harness.RootFS} {
benchmarks = append(benchmarks, benchmark{
clearCache: true,
fstype: filesys,
@@ -75,7 +76,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
filesystem := tools.Parameter{
Name: "filesystem",
- Value: bm.fstype,
+ Value: string(bm.fstype),
}
name, err := tools.ParametersToName(pageCache, filesystem)
if err != nil {
@@ -86,13 +87,14 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
// Grab a container.
ctx := context.Background()
container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
-
- mts, prefix, cleanup, err := harness.MakeMount(machine, bm.fstype)
+ cu := cleanup.Make(func() {
+ container.CleanUp(ctx)
+ })
+ defer cu.Clean()
+ mts, prefix, err := harness.MakeMount(machine, bm.fstype, &cu)
if err != nil {
b.Fatalf("Failed to make mount: %v", err)
}
- defer cleanup()
runOpts := dockerutil.RunOpts{
Image: image,
@@ -104,8 +106,9 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
b.Fatalf("run failed with: %v", err)
}
+ cpCmd := fmt.Sprintf("mkdir -p %s && cp -r %s %s/.", prefix, workdir, prefix)
if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
- "cp", "-rf", workdir, prefix+"/."); err != nil {
+ "/bin/sh", "-c", cpCmd); err != nil {
b.Fatalf("failed to copy directory: %v (%s)", err, out)
}
diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go
index f783a2b33..1482466f4 100644
--- a/test/benchmarks/fs/fio_test.go
+++ b/test/benchmarks/fs/fio_test.go
@@ -21,6 +21,7 @@ import (
"strings"
"testing"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
@@ -69,7 +70,7 @@ func BenchmarkFio(b *testing.B) {
}
defer machine.CleanUp()
- for _, fsType := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} {
+ for _, fsType := range []harness.FileSystemType{harness.BindFS, harness.TmpFS, harness.RootFS} {
for _, tc := range testCases {
operation := tools.Parameter{
Name: "operation",
@@ -81,7 +82,7 @@ func BenchmarkFio(b *testing.B) {
}
filesystem := tools.Parameter{
Name: "filesystem",
- Value: fsType,
+ Value: string(fsType),
}
name, err := tools.ParametersToName(operation, blockSize, filesystem)
if err != nil {
@@ -90,15 +91,18 @@ func BenchmarkFio(b *testing.B) {
b.Run(name, func(b *testing.B) {
b.StopTimer()
tc.Size = b.N
+
ctx := context.Background()
container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
+ cu := cleanup.Make(func() {
+ container.CleanUp(ctx)
+ })
+ defer cu.Clean()
- mnts, outdir, mountCleanup, err := harness.MakeMount(machine, fsType)
+ mnts, outdir, err := harness.MakeMount(machine, fsType, &cu)
if err != nil {
b.Fatalf("failed to make mount: %v", err)
}
- defer mountCleanup()
// Start the container with the mount.
if err := container.Spawn(
@@ -112,6 +116,11 @@ func BenchmarkFio(b *testing.B) {
b.Fatalf("failed to start fio container with: %v", err)
}
+ if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
+ "mkdir", "-p", outdir); err != nil {
+ b.Fatalf("failed to create output directory: %v (%s)", err, out)
+ }
+
// Directory and filename inside container where fio will read/write.
outfile := filepath.Join(outdir, "test.txt")
@@ -130,7 +139,6 @@ func BenchmarkFio(b *testing.B) {
}
cmd := tc.MakeCmd(outfile)
-
if err := harness.DropCaches(machine); err != nil {
b.Fatalf("failed to drop caches: %v", err)
}
diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD
index 116610938..367316661 100644
--- a/test/benchmarks/harness/BUILD
+++ b/test/benchmarks/harness/BUILD
@@ -12,6 +12,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/cleanup",
"//pkg/test/dockerutil",
"//pkg/test/testutil",
"@com_github_docker_docker//api/types/mount:go_default_library",
diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go
index 36abe1069..f7e569751 100644
--- a/test/benchmarks/harness/util.go
+++ b/test/benchmarks/harness/util.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -58,52 +59,55 @@ func DebugLog(b *testing.B, msg string, args ...interface{}) {
}
}
+// FileSystemType represents the type of a container mount.
+type FileSystemType string
+
const (
// BindFS indicates a bind mount should be created.
- BindFS = "bindfs"
+ BindFS FileSystemType = "bindfs"
// TmpFS indicates a tmpfs mount should be created.
- TmpFS = "tmpfs"
+ TmpFS FileSystemType = "tmpfs"
// RootFS indicates no mount should be created and the root mount should be used.
- RootFS = "rootfs"
+ RootFS FileSystemType = "rootfs"
)
// MakeMount makes a mount and cleanup based on the requested type. Bind
// and volume mounts are backed by a temp directory made with mktemp.
// tmpfs mounts require no such backing and are just made.
// rootfs mounts do not make a mount, but instead return a target direectory at root.
-// It is up to the caller to call the returned cleanup.
-func MakeMount(machine Machine, fsType string) ([]mount.Mount, string, func(), error) {
+// It is up to the caller to call Clean on the passed *cleanup.Cleanup.
+func MakeMount(machine Machine, fsType FileSystemType, cu *cleanup.Cleanup) ([]mount.Mount, string, error) {
mounts := make([]mount.Mount, 0, 1)
+ target := "/data"
switch fsType {
case BindFS:
dir, err := machine.RunCommand("mktemp", "-d")
if err != nil {
- return mounts, "", func() {}, fmt.Errorf("failed to create tempdir: %v", err)
+ return mounts, "", fmt.Errorf("failed to create tempdir: %v", err)
}
dir = strings.TrimSuffix(dir, "\n")
-
+ cu.Add(func() {
+ machine.RunCommand("rm", "-rf", dir)
+ })
out, err := machine.RunCommand("chmod", "777", dir)
if err != nil {
- machine.RunCommand("rm", "-rf", dir)
- return mounts, "", func() {}, fmt.Errorf("failed modify directory: %v %s", err, out)
+ return mounts, "", fmt.Errorf("failed modify directory: %v %s", err, out)
}
- target := "/data"
mounts = append(mounts, mount.Mount{
Target: target,
Source: dir,
Type: mount.TypeBind,
})
- return mounts, target, func() { machine.RunCommand("rm", "-rf", dir) }, nil
+ return mounts, target, nil
case RootFS:
- return mounts, "/", func() {}, nil
+ return mounts, target, nil
case TmpFS:
- target := "/data"
mounts = append(mounts, mount.Mount{
Target: target,
Type: mount.TypeTmpfs,
})
- return mounts, target, func() {}, nil
+ return mounts, target, nil
default:
- return mounts, "", func() {}, fmt.Errorf("illegal mount type not supported: %v", fsType)
+ return mounts, "", fmt.Errorf("illegal mount type not supported: %v", fsType)
}
}
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 567f64c41..34e83ec49 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -203,11 +203,6 @@ ALL_TESTS = [
name = "tcp_outside_the_window",
),
PacketimpactTestInfo(
- name = "tcp_outside_the_window_closing",
- # TODO(b/181625316): Fix netstack then merge into tcp_outside_the_window.
- expect_netstack_failure = True,
- ),
- PacketimpactTestInfo(
name = "tcp_noaccept_close_rst",
),
PacketimpactTestInfo(
@@ -217,11 +212,6 @@ ALL_TESTS = [
name = "tcp_unacc_seq_ack",
),
PacketimpactTestInfo(
- name = "tcp_unacc_seq_ack_closing",
- # TODO(b/181625316): Fix netstack then merge into tcp_unacc_seq_ack.
- expect_netstack_failure = True,
- ),
- PacketimpactTestInfo(
name = "tcp_paws_mechanism",
# TODO(b/156682000): Fix netstack then remove the line below.
expect_netstack_failure = True,
@@ -289,8 +279,6 @@ ALL_TESTS = [
),
PacketimpactTestInfo(
name = "tcp_fin_retransmission",
- # TODO(b/181625316): Fix netstack then remove the line below.
- expect_netstack_failure = True,
),
]
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index 1064ca976..b271bd47e 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -137,7 +137,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
dn := dn
t.Cleanup(func() {
if err := dn.Cleanup(ctx); err != nil {
- t.Errorf("unable to cleanup container %s: %s", dn.Name, err)
+ t.Errorf("failed to cleanup network %s: %s", dn.Name, err)
}
})
// Sanity check.
@@ -151,13 +151,15 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
info.testNet = testNet
// Create the Docker container for the DUT.
- var dut DUT
+ makeContainer := dockerutil.MakeContainer
if native {
- dut = mkDevice(dockerutil.MakeNativeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
- } else {
- dut = mkDevice(dockerutil.MakeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))))
+ makeContainer = dockerutil.MakeNativeContainer
}
- info.dut = dut
+ dutContainer := makeContainer(ctx, logger(fmt.Sprintf("dut-%d", id)))
+ t.Cleanup(func() {
+ dutContainer.CleanUp(ctx)
+ })
+ info.dut = mkDevice(dutContainer)
runOpts := dockerutil.RunOpts{
Image: "packetimpact",
@@ -168,7 +170,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
}
ipv4PrefixLength, _ := testNet.Subnet.Mask.Size()
- remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := dut.Prepare(ctx, t, runOpts, ctrlNet, testNet)
+ remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := info.dut.Prepare(ctx, t, runOpts, ctrlNet, testNet)
if err != nil {
return dutInfo{}, err
}
@@ -183,7 +185,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet),
POSIXServerPort: CtrlPort,
}
- info.uname, err = dut.Uname(ctx)
+ info.uname, err = info.dut.Uname(ctx)
if err != nil {
return dutInfo{}, fmt.Errorf("failed to get uname information on DUT: %w", err)
}
@@ -231,6 +233,9 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
// Create the Docker container for the testbench.
testbenchContainer := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+ t.Cleanup(func() {
+ testbenchContainer.CleanUp(ctx)
+ })
runOpts := dockerutil.RunOpts{
Image: "packetimpact",
@@ -598,7 +603,6 @@ func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, sysctls map[string]string, cmd ...string) error {
conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...)
_ = netconf
- hostconf.AutoRemove = true
hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
for k, v := range sysctls {
hostconf.Sysctls[k] = v
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index d5cb0ae06..c0deb33e5 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -124,17 +124,6 @@ packetimpact_testbench(
)
packetimpact_testbench(
- name = "tcp_outside_the_window_closing",
- srcs = ["tcp_outside_the_window_closing_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- "//test/packetimpact/testbench",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-packetimpact_testbench(
name = "tcp_noaccept_close_rst",
srcs = ["tcp_noaccept_close_rst_test.go"],
deps = [
@@ -166,17 +155,6 @@ packetimpact_testbench(
)
packetimpact_testbench(
- name = "tcp_unacc_seq_ack_closing",
- srcs = ["tcp_unacc_seq_ack_closing_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- "//test/packetimpact/testbench",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-packetimpact_testbench(
name = "tcp_paws_mechanism",
srcs = ["tcp_paws_mechanism_test.go"],
deps = [
diff --git a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go
deleted file mode 100644
index 1097746c7..000000000
--- a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_outside_the_window_closing_test
-
-import (
- "flag"
- "fmt"
- "testing"
- "time"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/test/packetimpact/testbench"
-)
-
-func init() {
- testbench.Initialize(flag.CommandLine)
-}
-
-// TestAckOTWSeqInClosing tests that the DUT should send an ACK with
-// the right ACK number when receiving a packet with OTW Seq number
-// in CLOSING state. https://tools.ietf.org/html/rfc793#page-69
-func TestAckOTWSeqInClosing(t *testing.T) {
- for seqNumOffset := seqnum.Size(0); seqNumOffset < 3; seqNumOffset++ {
- for _, tt := range []struct {
- description string
- flags header.TCPFlags
- payloads testbench.Layers
- }{
- {"SYN", header.TCPFlagSyn, nil},
- {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil},
- {"ACK", header.TCPFlagAck, nil},
- {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil},
- {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}},
- } {
- t.Run(fmt.Sprintf("%s%d", tt.description, seqNumOffset), func(t *testing.T) {
- dut := testbench.NewDUT(t)
- listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(t, listenFD)
- conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close(t)
- conn.Connect(t)
- acceptFD, _ := dut.Accept(t, listenFD)
- defer dut.Close(t, acceptFD)
-
- dut.Shutdown(t, acceptFD, unix.SHUT_WR)
-
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
- t.Fatalf("expected FINACK from DUT, but got none: %s", err)
- }
-
- // Do not ack the FIN from DUT so that the TCP state on DUT is CLOSING instead of CLOSED.
- seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
- conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
-
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
- t.Errorf("expected an ACK to our FIN, but got none: %s", err)
- }
-
- windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + seqNumOffset
- conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{
- SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
- AckNum: seqNumForTheirFIN,
- Flags: testbench.TCPFlags(tt.flags),
- }}, tt.payloads...))
-
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
- t.Errorf("expected an ACK but got none: %s", err)
- }
- })
- }
- }
-}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index 7cd7ff703..0523887d9 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -108,3 +108,75 @@ func TestTCPOutsideTheWindow(t *testing.T) {
})
}
}
+
+// TestAckOTWSeqInClosing tests that the DUT should send an ACK with
+// the right ACK number when receiving a packet with OTW Seq number
+// in CLOSING state. https://tools.ietf.org/html/rfc793#page-69
+func TestAckOTWSeqInClosing(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ flags header.TCPFlags
+ payloads testbench.Layers
+ seqNumOffset seqnum.Size
+ expectACK bool
+ }{
+ {"SYN", header.TCPFlagSyn, nil, 0, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
+ {"ACK", header.TCPFlagAck, nil, 0, false},
+ {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 0, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 0, false},
+
+ {"SYN", header.TCPFlagSyn, nil, 1, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
+ {"ACK", header.TCPFlagAck, nil, 1, true},
+ {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 1, true},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 1, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 2, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
+ {"ACK", header.TCPFlagAck, nil, 2, true},
+ {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 2, true},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected FINACK from DUT, but got none: %s", err)
+ }
+
+ // Do not ack the FIN from DUT so that the TCP state on DUT is CLOSING instead of CLOSED.
+ seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
+ conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK to our FIN, but got none: %s", err)
+ }
+
+ windowSize := seqnum.Size(*gotTCP.WindowSize) + tt.seqNumOffset
+ conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
+ AckNum: seqNumForTheirFIN,
+ Flags: testbench.TCPFlags(tt.flags),
+ }}, tt.payloads...))
+
+ gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
+ if tt.expectACK && err != nil {
+ t.Errorf("expected an ACK but got none: %s", err)
+ }
+ if !tt.expectACK && gotACK != nil {
+ t.Errorf("expected no ACK but got one: %s", gotACK)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go
deleted file mode 100644
index a208210ac..000000000
--- a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_unacc_seq_ack_closing_test
-
-import (
- "flag"
- "fmt"
- "testing"
- "time"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/test/packetimpact/testbench"
-)
-
-func init() {
- testbench.Initialize(flag.CommandLine)
-}
-
-func TestSimultaneousCloseUnaccSeqAck(t *testing.T) {
- for _, tt := range []struct {
- description string
- makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
- seqNumOffset seqnum.Size
- expectAck bool
- }{
- {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true},
- {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
- {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
- {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
- {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
- {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
- } {
- t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
- dut := testbench.NewDUT(t)
- listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
- defer dut.Close(t, listenFD)
- conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close(t)
-
- conn.Connect(t)
- acceptFD, _ := dut.Accept(t, listenFD)
-
- // Trigger active close.
- dut.Shutdown(t, acceptFD, unix.SHUT_WR)
-
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
- if err != nil {
- t.Fatalf("expected a FIN: %s", err)
- }
- // Do not ack the FIN from DUT so that we get to CLOSING.
- seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
- conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
-
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
- t.Errorf("expected an ACK to our FIN, but got none: %s", err)
- }
-
- sampleData := []byte("Sample Data")
- samplePayload := &testbench.Payload{Bytes: sampleData}
-
- origSeq := uint32(*conn.LocalSeqNum(t))
- // Send a segment with OTW Seq / unacc ACK.
- tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize))
- if tt.description == "OTWSeq" {
- // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that
- // we stay in CLOSING.
- tcp.AckNum = seqNumForTheirFIN
- }
- conn.Send(t, tcp, samplePayload)
-
- got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
- if tt.expectAck && err != nil {
- t.Errorf("expected an ack in CLOSING state, but got none: %s", err)
- }
- if !tt.expectAck && got != nil {
- t.Errorf("expected no ack in CLOSING state, but got one: %s", got)
- }
- })
- }
-}
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
index ce0a26171..389bfc629 100644
--- a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
+++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
@@ -209,3 +209,66 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
})
}
}
+
+func TestSimultaneousCloseUnaccSeqAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: false},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
+ } {
+ t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ // Trigger active close.
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN: %s", err)
+ }
+ // Do not ack the FIN from DUT so that we get to CLOSING.
+ seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
+ conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Errorf("expected an ACK to our FIN, but got none: %s", err)
+ }
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ origSeq := uint32(*conn.LocalSeqNum(t))
+ // Send a segment with OTW Seq / unacc ACK.
+ tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize))
+ if tt.description == "OTWSeq" {
+ // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that
+ // we stay in CLOSING.
+ tcp.AckNum = seqNumForTheirFIN
+ }
+ conn.Send(t, tcp, samplePayload)
+
+ got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Errorf("expected an ack in CLOSING state, but got none: %s", err)
+ }
+ if !tt.expectAck && got != nil {
+ t.Errorf("expected no ack in CLOSING state, but got one: %s", got)
+ }
+ })
+ }
+}
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 4509b5e55..043ada583 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1922,7 +1922,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
gtest,
+ "//test/util:cleanup",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -2162,6 +2164,7 @@ cc_binary(
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
gtest,
],
@@ -3990,6 +3993,7 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/base:core_headers",
gtest,
"//test/util:temp_path",
"//test/util:test_main",
diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc
index 98d5e432d..087262535 100644
--- a/test/syscalls/linux/read.cc
+++ b/test/syscalls/linux/read.cc
@@ -13,11 +13,14 @@
// limitations under the License.
#include <fcntl.h>
+#include <sys/mman.h>
#include <unistd.h>
#include <vector>
#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -121,6 +124,43 @@ TEST_F(ReadTest, ReadWithOpath) {
EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallFailsWithErrno(EBADF));
}
+// Test that partial reads that fault (SIGSEGV) on the destination buffer are
+// correctly handled and return a partial read.
+TEST_F(ReadTest, PartialReadSIGSEGV) {
+ // Allocate 2 pages and remove permission from the second.
+ const size_t size = 2 * kPageSize;
+ void* addr =
+ mmap(0, size, PROT_WRITE | PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, 0, 0);
+ ASSERT_NE(addr, MAP_FAILED);
+ auto cleanup = Cleanup(
+ [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); });
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(name_.c_str(), O_RDWR, 0666));
+ for (size_t i = 0; i < 2; i++) {
+ EXPECT_THAT(pwrite(fd.get(), addr, size, 0),
+ SyscallSucceedsWithValue(size));
+ }
+
+ void* badAddr = reinterpret_cast<char*>(addr) + kPageSize;
+ ASSERT_THAT(mprotect(badAddr, kPageSize, PROT_NONE), SyscallSucceeds());
+
+ // Attempt to read to both pages. Create a non-contiguous iovec pair to
+ // ensure operation is done in 2 steps.
+ struct iovec iov[] = {
+ {
+ .iov_base = addr,
+ .iov_len = kPageSize,
+ },
+ {
+ .iov_base = addr,
+ .iov_len = size,
+ },
+ };
+ EXPECT_THAT(preadv(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0),
+ SyscallSucceedsWithValue(size));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/setgid.cc b/test/syscalls/linux/setgid.cc
index cd030b094..98f8f3dfe 100644
--- a/test/syscalls/linux/setgid.cc
+++ b/test/syscalls/linux/setgid.cc
@@ -17,6 +17,7 @@
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/fs_util.h"
@@ -24,6 +25,11 @@
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+ABSL_FLAG(std::vector<std::string>, groups, std::vector<std::string>({}),
+ "groups the test can use");
+
+constexpr gid_t kNobody = 65534;
+
namespace gvisor {
namespace testing {
@@ -46,6 +52,18 @@ PosixErrorOr<Cleanup> Setegid(gid_t egid) {
// Returns a pair of groups that the user is a member of.
PosixErrorOr<std::pair<gid_t, gid_t>> Groups() {
+ // Were we explicitly passed GIDs?
+ std::vector<std::string> flagged_groups = absl::GetFlag(FLAGS_groups);
+ if (flagged_groups.size() >= 2) {
+ int group1;
+ int group2;
+ if (!absl::SimpleAtoi(flagged_groups[0], &group1) ||
+ !absl::SimpleAtoi(flagged_groups[1], &group2)) {
+ return PosixError(EINVAL, "failed converting group flags to ints");
+ }
+ return std::pair<gid_t, gid_t>(group1, group2);
+ }
+
// See whether the user is a member of at least 2 groups.
std::vector<gid_t> groups(64);
for (; groups.size() <= NGROUPS_MAX; groups.resize(groups.size() * 2)) {
@@ -58,26 +76,47 @@ PosixErrorOr<std::pair<gid_t, gid_t>> Groups() {
return PosixError(errno, absl::StrFormat("getgroups(%d, %p)",
groups.size(), groups.data()));
}
- if (ngroups >= 2) {
- return std::pair<gid_t, gid_t>(groups[0], groups[1]);
+
+ if (ngroups < 2) {
+ // There aren't enough groups.
+ break;
+ }
+
+ // TODO(b/181878080): Read /proc/sys/fs/overflowgid once it is supported in
+ // gVisor.
+ if (groups[0] == kNobody || groups[1] == kNobody) {
+ // These groups aren't mapped into our user namespace, so we can't use
+ // them.
+ break;
}
- // There aren't enough groups.
- break;
+ return std::pair<gid_t, gid_t>(groups[0], groups[1]);
}
- // If we're root in the root user namespace, we can set our GID to whatever we
- // want. Try that before giving up.
- constexpr gid_t kGID1 = 1111;
- constexpr gid_t kGID2 = 2222;
- auto cleanup1 = Setegid(kGID1);
+ // If we're running in gVisor and are root in the root user namespace, we can
+ // set our GID to whatever we want. Try that before giving up.
+ //
+ // This won't work in native tests, as despite having CAP_SETGID, the gofer
+ // process will be sandboxed and unable to change file GIDs.
+ if (!IsRunningOnGvisor()) {
+ return PosixError(EPERM, "no valid groups for native testing");
+ }
+ PosixErrorOr<bool> capable = HaveCapability(CAP_SETGID);
+ if (!capable.ok()) {
+ return capable.error();
+ }
+ if (!capable.ValueOrDie()) {
+ return PosixError(EPERM, "missing CAP_SETGID");
+ }
+ gid_t gid = getegid();
+ auto cleanup1 = Setegid(gid);
if (!cleanup1.ok()) {
return cleanup1.error();
}
- auto cleanup2 = Setegid(kGID2);
+ auto cleanup2 = Setegid(kNobody);
if (!cleanup2.ok()) {
return cleanup2.error();
}
- return std::pair<gid_t, gid_t>(kGID1, kGID2);
+ return std::pair<gid_t, gid_t>(gid, kNobody);
}
class SetgidDirTest : public ::testing::Test {
@@ -85,17 +124,21 @@ class SetgidDirTest : public ::testing::Test {
void SetUp() override {
original_gid_ = getegid();
- // TODO(b/175325250): Enable when setgid directories are supported.
SKIP_IF(IsRunningWithVFS1());
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETGID)));
+ // If we can't find two usable groups, we're in an unsupported environment.
+ // Skip the test.
+ PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups();
+ SKIP_IF(!groups.ok());
+ groups_ = groups.ValueOrDie();
+
+ auto cleanup = Setegid(groups_.first);
temp_dir_ = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */));
- groups_ = ASSERT_NO_ERRNO_AND_VALUE(Groups());
}
void TearDown() override {
- ASSERT_THAT(setegid(original_gid_), SyscallSucceeds());
+ EXPECT_THAT(setegid(original_gid_), SyscallSucceeds());
}
void MkdirAsGid(gid_t gid, const std::string& path, mode_t mode) {
@@ -131,7 +174,7 @@ TEST_F(SetgidDirTest, Control) {
ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, 0777));
// Set group to G2, create a file in g1owned, and confirm that G2 owns it.
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds());
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second));
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Open(JoinPath(g1owned, "g2owned").c_str(), O_CREAT | O_RDWR, 0777));
struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd));
@@ -146,7 +189,7 @@ TEST_F(SetgidDirTest, CreateFile) {
ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds());
// Set group to G2, create a file, and confirm that G1 owns it.
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds());
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second));
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666));
struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd));
@@ -194,7 +237,7 @@ TEST_F(SetgidDirTest, OldFile) {
ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds());
// Set group to G2, create a file, confirm that G2 owns it.
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds());
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second));
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666));
struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd));
@@ -217,7 +260,7 @@ TEST_F(SetgidDirTest, OldDir) {
ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds());
// Set group to G2, create a directory, confirm that G2 owns it.
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds());
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second));
auto g2created = JoinPath(g1owned, "g2created");
ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666));
struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created));
@@ -306,6 +349,10 @@ class FileModeTest : public ::testing::TestWithParam<FileModeTestcase> {};
TEST_P(FileModeTest, WriteToFile) {
SKIP_IF(IsRunningWithVFS1());
+ PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups();
+ SKIP_IF(!groups.ok());
+
+ auto cleanup = Setegid(groups.ValueOrDie().first);
auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */));
auto path = JoinPath(temp_dir.path(), GetParam().name);
@@ -329,26 +376,28 @@ TEST_P(FileModeTest, WriteToFile) {
TEST_P(FileModeTest, TruncateFile) {
SKIP_IF(IsRunningWithVFS1());
+ PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups();
+ SKIP_IF(!groups.ok());
+
+ auto cleanup = Setegid(groups.ValueOrDie().first);
auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */));
auto path = JoinPath(temp_dir.path(), GetParam().name);
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666));
- ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds());
- struct stat stats;
- ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds());
- EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode);
// Write something to the file, as truncating an empty file is a no-op.
constexpr char c = 'M';
ASSERT_THAT(write(fd.get(), &c, sizeof(c)),
SyscallSucceedsWithValue(sizeof(c)));
+ ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds());
// For security reasons, truncating the file clears the SUID bit, and clears
// the SGID bit when the group executable bit is unset (which is not a true
// SGID binary).
ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds());
+ struct stat stats;
ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds());
EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode);
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 54b45b075..597b5bcb1 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -490,7 +490,11 @@ void TestListenWhileConnect(const TestParam& param,
TestAddress const& connector = param.connector;
constexpr int kBacklog = 2;
- constexpr int kClients = kBacklog + 1;
+ // Linux completes one more connection than the listen backlog argument.
+ // To ensure that there is at least one client connection that stays in
+ // connecting state, keep 2 more client connections than the listen backlog.
+ // gVisor differs in this behavior though, gvisor.dev/issue/3153.
+ constexpr int kClients = kBacklog + 2;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
@@ -527,7 +531,7 @@ void TestListenWhileConnect(const TestParam& param,
for (auto& client : clients) {
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = client.get(),
.events = POLLIN,
};
@@ -543,6 +547,10 @@ void TestListenWhileConnect(const TestParam& param,
ASSERT_THAT(read(client.get(), &c, sizeof(c)),
AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(ECONNREFUSED)));
+ // The last client connection would be in connecting (SYN_SENT) state.
+ if (client.get() == clients[kClients - 1].get()) {
+ ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);
+ }
}
}
@@ -598,7 +606,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = conn_fd.get(),
.events = POLLOUT,
};
@@ -623,6 +631,95 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
}
}
+// Test if the stack completes at most listen backlog number of client
+// connections. It exercises the path of the stack that enqueues completed
+// connections to accept queue vs new incoming SYNs.
+TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
+ const auto& param = GetParam();
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 1;
+ // Keep the number of client connections more than the listen backlog.
+ // Linux completes one more connection than the listen backlog argument.
+ // gVisor differs in this behavior though, gvisor.dev/issue/3153.
+ int kClients = kBacklog + 2;
+ if (IsRunningOnGvisor()) {
+ kClients--;
+ }
+
+  // Run the following test for a few iterations to exercise the race of the
+  // accept queue getting filled with incoming SYNs.
+ for (int num = 0; num < 10; num++) {
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ std::vector<FileDescriptor> clients;
+ // Issue multiple non-blocking client connects.
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
+ }
+
+ // Now that client connects are issued, wait for the accept queue to get
+ // filled and ensure no new client connection is completed.
+ for (int i = 0; i < kClients; i++) {
+ pollfd pfd = {
+ .fd = clients[i].get(),
+ .events = POLLOUT,
+ };
+ if (i < kClients - 1) {
+ // Poll for client side connection completions with a large timeout.
+ // We cannot poll on the listener side without calling accept as poll
+ // stays level triggered with non-zero accept queue length.
+ //
+ // Client side poll would not guarantee that the completed connection
+        // has been enqueued into the accept queue, but the fact that the
+ // listener ACKd the SYN, means that it cannot complete any new incoming
+ // SYNs when it has already ACKd for > backlog number of SYNs.
+ ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
+ << "num=" << num << " i=" << i << " kClients=" << kClients;
+ ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
+ } else {
+ // Now that we expect accept queue filled up, ensure that the last
+ // client connection never completes with a smaller poll timeout.
+ ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+
+ ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+ clients.clear();
+ // We close the listening side and open a new listener. We could instead
+ // drain the accept queue by calling accept() and reuse the listener, but
+ // that is racy as the retransmitted SYNs could get ACKd as we make room in
+ // the accept queue.
+ ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ }
+}
+
// TCPFinWait2Test creates a pair of connected sockets then closes one end to
// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local
// IP/port on a new socket and tries to connect. The connect should fail w/
@@ -937,7 +1034,7 @@ void setupTimeWaitClose(const TestAddress* listener,
ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds());
{
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = passive_closefd.get(),
.events = POLLIN,
};
@@ -948,7 +1045,7 @@ void setupTimeWaitClose(const TestAddress* listener,
{
constexpr int kTimeout = 10000;
constexpr int16_t want_events = POLLHUP;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = active_closefd.get(),
.events = want_events,
};
@@ -1181,7 +1278,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
// Wait for accept_fd to process the RST.
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = accept_fd.get(),
.events = POLLIN,
};
@@ -1705,7 +1802,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
SyscallSucceedsWithValue(sizeof(i)));
}
- struct pollfd pollfds[kThreadCount];
+ pollfd pollfds[kThreadCount];
for (int i = 0; i < kThreadCount; i++) {
pollfds[i].fd = listener_fds[i].get();
pollfds[i].events = POLLIN;
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
index 740992d0a..3373ba72b 100644
--- a/test/syscalls/linux/write.cc
+++ b/test/syscalls/linux/write.cc
@@ -15,6 +15,7 @@
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
+#include <sys/mman.h>
#include <sys/resource.h>
#include <sys/stat.h>
#include <sys/types.h>
@@ -23,6 +24,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/base/macros.h"
#include "test/util/cleanup.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -256,6 +258,82 @@ TEST_F(WriteTest, PwriteWithOpath) {
SyscallFailsWithErrno(EBADF));
}
+// Test that partial writes that hit SIGSEGV are correctly handled and return
+// partial write.
+TEST_F(WriteTest, PartialWriteSIGSEGV) {
+ // Allocate 2 pages and remove permission from the second.
+ const size_t size = 2 * kPageSize;
+ void* addr = mmap(0, size, PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, 0, 0);
+ ASSERT_NE(addr, MAP_FAILED);
+ auto cleanup = Cleanup(
+ [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); });
+
+ void* badAddr = reinterpret_cast<char*>(addr) + kPageSize;
+ ASSERT_THAT(mprotect(badAddr, kPageSize, PROT_NONE), SyscallSucceeds());
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_WRONLY));
+
+ // Attempt to write both pages to the file. Create a non-contiguous iovec pair
+ // to ensure operation is done in 2 steps.
+ struct iovec iov[] = {
+ {
+ .iov_base = addr,
+ .iov_len = kPageSize,
+ },
+ {
+ .iov_base = addr,
+ .iov_len = size,
+ },
+ };
+ // Write should succeed for the first iovec and half of the second (=2 pages).
+ EXPECT_THAT(pwritev(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0),
+ SyscallSucceedsWithValue(2 * kPageSize));
+}
+
+// Test that partial writes that hit SIGBUS are correctly handled and return
+// partial write.
+TEST_F(WriteTest, PartialWriteSIGBUS) {
+ SKIP_IF(getenv("GVISOR_GOFER_UNCACHED")); // Can't mmap from uncached files.
+
+ TempPath mapfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd_map =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(mapfile.path().c_str(), O_RDWR));
+
+ // Let the first page be read to force a partial write.
+ ASSERT_THAT(ftruncate(fd_map.get(), kPageSize), SyscallSucceeds());
+
+ // Map 2 pages, one of which is not allocated in the backing file. Reading
+ // from it will trigger a SIGBUS.
+ const size_t size = 2 * kPageSize;
+ void* addr =
+ mmap(NULL, size, PROT_READ, MAP_FILE | MAP_PRIVATE, fd_map.get(), 0);
+ ASSERT_NE(addr, MAP_FAILED);
+ auto cleanup = Cleanup(
+ [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); });
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_WRONLY));
+
+ // Attempt to write both pages to the file. Create a non-contiguous iovec pair
+ // to ensure operation is done in 2 steps.
+ struct iovec iov[] = {
+ {
+ .iov_base = addr,
+ .iov_len = kPageSize,
+ },
+ {
+ .iov_base = addr,
+ .iov_len = size,
+ },
+ };
+ // Write should succeed for the first iovec and half of the second (=2 pages).
+ ASSERT_THAT(pwritev(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0),
+ SyscallSucceedsWithValue(2 * kPageSize));
+}
+
} // namespace
} // namespace testing
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index 44cb33ae4..c2747d94c 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -8,6 +8,7 @@ go_library(
"generator.go",
"generator_interfaces.go",
"generator_interfaces_array_newtype.go",
+ "generator_interfaces_dynamic.go",
"generator_interfaces_primitive_newtype.go",
"generator_interfaces_struct.go",
"generator_tests.go",
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 634abd1af..39394d2a7 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -126,6 +126,12 @@ func (g *Generator) writeHeader() error {
b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
// Emit build tags.
+ b.emit("// If there are issues with build tag aggregation, see\n")
+ b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n")
+ b.emit("// come from the input set of files used to generate this file. This input set\n")
+ b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n")
+ b.emit("// tools/defs.bzl:calculate_sets().\n\n")
+
if t := tags.Aggregate(g.inputs); len(t) > 0 {
b.emit(strings.Join(t.Lines(), "\n"))
b.emit("\n\n")
@@ -381,36 +387,29 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
i := newInterfaceGenerator(t.spec, t.recv, fset)
+ if t.dynamic {
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.")
+ }
+ // No validation needed, assume the user knows what they are doing.
+ i.emitMarshallableForDynamicType()
+ return i
+ }
switch ty := t.spec.Type.(type) {
case *ast.StructType:
- if t.dynamic {
- // Don't validate because this type is dynamically sized and probably
- // contains some funky slices which the validation does not allow.
- i.emitMarshallableForStruct(ty, t.dynamic)
- if t.slice != nil {
- abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.")
- }
- break
- }
i.validateStruct(t.spec, ty)
- i.emitMarshallableForStruct(ty, t.dynamic)
+ i.emitMarshallableForStruct(ty)
if t.slice != nil {
i.emitMarshallableSliceForStruct(ty, t.slice)
}
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
- if t.dynamic {
- abortAt(fset.Position(t.slice.comment.Slash), "Primitive type marked as '+marshal dynamic', but primitive types can not be dynamic.")
- }
i.emitMarshallableForPrimitiveNewtype(ty)
if t.slice != nil {
i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
}
case *ast.ArrayType:
i.validateArrayNewtype(t.spec.Name, ty)
- if t.dynamic {
- abortAt(fset.Position(t.slice.comment.Slash), "Marking array types as `dynamic` is currently not supported.")
- }
// After validate, we can safely call arrayLen.
i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
if t.slice != nil {
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
new file mode 100644
index 000000000..b1a8622cd
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
@@ -0,0 +1,96 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+func (g *interfaceGenerator) emitMarshallableForDynamicType() {
+ // The user writes their own MarshalBytes, UnmarshalBytes and SizeBytes for
+ // dynamic types. Generate the rest using these definitions.
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s is dynamic so it might have slice/string headers. Hence, it is not packed.\n", g.typeName())
+ g.emit("return false\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
+ g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("length, err := writer.Write(buf)\n")
+ g.emit("return int64(length), err\n")
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index f98e41ed7..5f6306b8f 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -69,11 +69,7 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
})
}
-func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool) bool {
- if isDynamic {
- // Dynamic types are not packed because a slice header might be present.
- return false
- }
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
packed := true
forEachStructField(st, func(f *ast.Field) {
if f.Tag != nil {
@@ -89,17 +85,165 @@ func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool)
return packed
}
-func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDynamic bool) {
- thisPacked := g.isStructPacked(st, isDynamic)
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
- // Dynamic types are supposed to manually implement SizeBytes, MarshalBytes
- // and UnmarshalBytes. The rest of the methos are autogenerated and depend on
- // the implementation of these three.
- if !isDynamic {
- g.emitSizeBytesForStruct(st)
- g.emitMarshalBytesForStruct(st)
- g.emitUnmarshalBytesForStruct(st)
- }
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(_, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+						// an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
g.emit("// Packed implements marshal.Marshallable.Packed.\n")
g.emit("//go:nosplit\n")
@@ -284,171 +428,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDyn
g.emit("}\n\n")
}
-func (g *interfaceGenerator) emitSizeBytesForStruct(st *ast.StructType) {
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- forEachStructField(st, fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(_, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if size, dynamic := g.scalarSize(t); !dynamic {
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-}
-
-func (g *interfaceGenerator) emitMarshalBytesForStruct(st *ast.StructType) {
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
- g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- return
- }
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-}
-
-func (g *interfaceGenerator) emitUnmarshalBytesForStruct(st *ast.StructType) {
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We don't have an instance of the dynamic type we can
- // reference here (since the version in this struct is
- // anonymous). Use a typed nil pointer to call
- // SizeBytes() instead.
- g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
- g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("src = src[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-}
-
func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
- thisPacked := g.isStructPacked(st, false /* isDynamic */)
+ thisPacked := g.isStructPacked(st)
if slice.inner {
abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index cb2d4e6e3..5bceacd32 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -23,7 +23,10 @@ go_test(
go_library(
name = "test",
testonly = 1,
- srcs = ["test.go"],
+ srcs = [
+ "dynamic.go",
+ "test.go",
+ ],
marshal = True,
visibility = ["//tools/go_marshal/test:__subpackages__"],
deps = [
diff --git a/tools/go_marshal/test/dynamic.go b/tools/go_marshal/test/dynamic.go
new file mode 100644
index 000000000..9a812efe9
--- /dev/null
+++ b/tools/go_marshal/test/dynamic.go
@@ -0,0 +1,83 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import "gvisor.dev/gvisor/pkg/marshal/primitive"
+
+// Type12Dynamic is a dynamically sized struct which depends on the
+// autogenerator to generate some Marshallable methods for it.
+//
+// +marshal dynamic
+type Type12Dynamic struct {
+ X primitive.Int64
+ Y []primitive.Int64
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *Type12Dynamic) SizeBytes() int {
+ return (len(t.Y) * 8) + t.X.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *Type12Dynamic) MarshalBytes(dst []byte) {
+ t.X.MarshalBytes(dst)
+ dst = dst[t.X.SizeBytes():]
+ for i, x := range t.Y {
+ x.MarshalBytes(dst[i*8 : (i+1)*8])
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *Type12Dynamic) UnmarshalBytes(src []byte) {
+ t.X.UnmarshalBytes(src)
+ if t.Y != nil {
+ t.Y = t.Y[:0]
+ }
+ for i := t.X.SizeBytes(); i < len(src); i += 8 {
+ var x primitive.Int64
+ x.UnmarshalBytes(src[i:])
+ t.Y = append(t.Y, x)
+ }
+}
+
+// Type13Dynamic is a dynamically sized struct which depends on the
+// autogenerator to generate some Marshallable methods for it.
+//
+// It represents a string in memory which is preceded by a uint32 indicating
+// the string size.
+//
+// +marshal dynamic
+type Type13Dynamic string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *Type13Dynamic) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + len(*t)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *Type13Dynamic) MarshalBytes(dst []byte) {
+ strLen := primitive.Uint32(len(*t))
+ strLen.MarshalBytes(dst)
+ dst = dst[strLen.SizeBytes():]
+ copy(dst[:strLen], *t)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *Type13Dynamic) UnmarshalBytes(src []byte) {
+ var strLen primitive.Uint32
+ strLen.UnmarshalBytes(src)
+ src = src[strLen.SizeBytes():]
+ *t = Type13Dynamic(src[:strLen])
+}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
index b0091dc64..733689c79 100644
--- a/tools/go_marshal/test/marshal_test.go
+++ b/tools/go_marshal/test/marshal_test.go
@@ -515,20 +515,39 @@ func TestLimitedSliceMarshalling(t *testing.T) {
}
}
-func TestDynamicType(t *testing.T) {
+func TestDynamicTypeStruct(t *testing.T) {
t12 := test.Type12Dynamic{
X: 32,
Y: []primitive.Int64{5, 6, 7},
}
+ var cc mockCopyContext
+ cc.setLimit(t12.SizeBytes())
- var m marshal.Marshallable
- m = &t12 // Ensure that all methods were generated.
- b := make([]byte, m.SizeBytes())
- m.MarshalBytes(b)
+ if _, err := t12.CopyOut(&cc, usermem.Addr(0)); err != nil {
+ t.Fatalf("cc.CopyOut failed: %v", err)
+ }
- var res test.Type12Dynamic
- res.UnmarshalBytes(b)
+ res := test.Type12Dynamic{
+ Y: make([]primitive.Int64, len(t12.Y)),
+ }
+ res.CopyIn(&cc, usermem.Addr(0))
if !reflect.DeepEqual(t12, res) {
t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %+v, after = %+v", t12, res)
}
}
+
+func TestDynamicTypeIdentifier(t *testing.T) {
+ s := test.Type13Dynamic("go_marshal")
+ var cc mockCopyContext
+ cc.setLimit(s.SizeBytes())
+
+ if _, err := s.CopyOut(&cc, usermem.Addr(0)); err != nil {
+ t.Fatalf("cc.CopyOut failed: %v", err)
+ }
+
+ res := test.Type13Dynamic(make([]byte, len(s)))
+ res.CopyIn(&cc, usermem.Addr(0))
+ if res != s {
+ t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %s, after = %s", s, res)
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index b8eb989d9..e7e3ed74a 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -16,8 +16,6 @@
package test
import (
- "gvisor.dev/gvisor/pkg/marshal/primitive"
-
// We're intentionally using a package name alias here even though it's not
// necessary to test the code generator's ability to handle package aliases.
ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
@@ -200,36 +198,3 @@ type Type11 struct {
ex.External
y int64
}
-
-// Type12Dynamic is a dynamically sized struct which depends on the autogenerator
-// to generate some Marshallable methods for it.
-//
-// +marshal dynamic
-type Type12Dynamic struct {
- X primitive.Int64
- Y []primitive.Int64
-}
-
-// SizeBytes implements marshal.Marshallable.SizeBytes.
-func (t *Type12Dynamic) SizeBytes() int {
- return (len(t.Y) * 8) + t.X.SizeBytes()
-}
-
-// MarshalBytes implements marshal.Marshallable.MarshalBytes.
-func (t *Type12Dynamic) MarshalBytes(dst []byte) {
- t.X.MarshalBytes(dst)
- dst = dst[t.X.SizeBytes():]
- for i, x := range t.Y {
- x.MarshalBytes(dst[i*8 : (i+1)*8])
- }
-}
-
-// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
-func (t *Type12Dynamic) UnmarshalBytes(src []byte) {
- t.X.UnmarshalBytes(src)
- for i := t.X.SizeBytes(); i < len(src); i += 8 {
- var x primitive.Int64
- x.UnmarshalBytes(src[i:])
- t.Y = append(t.Y, x)
- }
-}
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index e1de12e25..93022f504 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -403,6 +403,7 @@ func main() {
// on this specific behavior, but the ability to specify slots
// allows a manual implementation to be order-dependent.
if generateSaverLoader {
+ fmt.Fprintf(outputFile, "// +checklocksignore\n")
fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv)
scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
@@ -425,6 +426,7 @@ func main() {
//
// N.B. See the comment above for the save method.
if generateSaverLoader {
+ fmt.Fprintf(outputFile, "// +checklocksignore\n")
fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
scanFields(x, "", scanFunctions{value: emitLoadValue})