diff options
119 files changed, 3649 insertions, 2187 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index aa2fd1f47..3bc5041c0 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -186,10 +186,14 @@ steps: # For fio, running with --test.benchtime=Xs scales the written/read # bytes to several GB. This is not a problem for root/bind/volume mounts, # but for tmpfs mounts, the size can grow to more memory than the machine - # has availabe. Fix the runs to 10GB written/read for the benchmark. + # has available. Fix the runs to 1GB written/read for the benchmark. - <<: *benchmarks - label: ":floppy_disk: FIO benchmarks" - command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_OPTIONS=--test.benchtime=10000x + label: ":floppy_disk: FIO benchmarks (read/write)" + command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_FILTER=Fio/operation\.[rw][er] BENCHMARKS_OPTIONS=--test.benchtime=1000x + # For rand(read|write) fio benchmarks, running 15s does not overwhelm the system for tmpfs mounts. 
+ - <<: *benchmarks + label: ":cd: FIO benchmarks (randread/randwrite)" + command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_FILTER=Fio/operation\.rand BENCHMARKS_OPTIONS=--test.benchtime=15s - <<: *benchmarks label: ":globe_with_meridians: HTTPD benchmarks" command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=httpd BENCHMARKS_TARGETS=test/benchmarks/network:httpd_test diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index ad0ab9923..bcfba0179 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -59,7 +59,7 @@ Next, the configure the key used to sign archives and the repository: ```bash curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main" ``` Now the runsc package can be installed: @@ -96,7 +96,7 @@ You can use this link with the steps described in For `apt` installation, use the `master` to configure the repository: ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases master main" +sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases master main" ``` ### Nightly @@ -118,7 +118,7 @@ Note that a release may not be available for every day. 
For `apt` installation, use the `nightly` to configure the repository: ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases nightly main" +sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases nightly main" ``` ### Latest release @@ -133,7 +133,7 @@ You can use this link with the steps described in For `apt` installation, use the `release` to configure the repository: ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main" ``` ### Specific release @@ -152,7 +152,7 @@ For `apt` installation of a specific release, which may include point updates, use the date of the release for repository, e.g. `${yyyymmdd}`. ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases yyyymmdd main" +sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases yyyymmdd main" ``` > Note: only newer releases may be available as `apt` repositories. diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 2235f8968..648cf4b49 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -151,9 +151,16 @@ const ( // Sticky is a mode bit indicating sticky directories. Sticky FileMode = 01000 + // SetGID is the set group ID bit. + SetGID FileMode = 02000 + + // SetUID is the set user ID bit. + SetUID FileMode = 04000 + // permissionsMask is the mask to apply to FileModes for permissions. It - // includes rwx bits for user, group and others, and sticky bit. - permissionsMask FileMode = 01777 + // includes rwx bits for user, group, and others, as well as the sticky + // bit, setuid bit, and setgid bit. 
+ permissionsMask FileMode = 07777 ) // QIDType is the most significant byte of the FileMode word, to be used as the diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD index d1b14efdb..885958456 100644 --- a/pkg/ring0/BUILD +++ b/pkg/ring0/BUILD @@ -80,6 +80,7 @@ go_library( "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/usermem", ], ) diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go index e8ce608ba..b6e2012e8 100644 --- a/pkg/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -17,6 +17,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" ) // Kernel is a global kernel object. @@ -96,7 +97,7 @@ type SwitchOpts struct { // FloatingPointState is a byte pointer where floating point state is // saved and restored. - FloatingPointState arch.FloatingPointData + FloatingPointState *fpu.State // PageTables are the application page tables. PageTables *pagetables.PageTables diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD index 15b93d61c..f421e1687 100644 --- a/pkg/ring0/gen_offsets/BUILD +++ b/pkg/ring0/gen_offsets/BUILD @@ -35,6 +35,7 @@ go_binary( "//pkg/cpuid", "//pkg/ring0/pagetables", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/usermem", ], ) diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index e9e706716..33c259757 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -239,17 +239,17 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Ss = uint64(Udata) // Ditto. // Perform the switch. - swapgs() // GS will be swapped on return. - WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. - WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. - LoadFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy in floating point. + swapgs() // GS will be swapped on return. 
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. + WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. + LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point. if switchOpts.FullRestore { vector = iret(c, regs, uintptr(userCR3)) } else { vector = sysret(c, regs, uintptr(userCR3)) } - SaveFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy out floating point. - WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. + SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. + WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return } diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index c9a120952..7975e5f92 100644 --- a/pkg/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go @@ -62,7 +62,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) - storeEl0Fpstate(&switchOpts.FloatingPointState[0]) + storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer()) if switchOpts.Flush { FlushTlbByASID(uintptr(switchOpts.UserASID)) @@ -82,7 +82,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { fpDisableTrap = CPACREL1() if fpDisableTrap != 0 { - SaveFloatingPoint(&switchOpts.FloatingPointState[0]) + SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) } vector = c.vecCode diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 85278b389..f660f1614 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -9,7 +9,6 @@ go_library( "arch.go", "arch_aarch64.go", "arch_amd64.go", - "arch_amd64.s", "arch_arm64.go", "arch_state_x86.go", "arch_x86.go", @@ -36,8 +35,8 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", + "//pkg/sentry/arch/fpu", "//pkg/sentry/limits", - "//pkg/sync", "//pkg/syserror", 
"//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 3443b9e1b..921151137 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -50,12 +51,6 @@ func (a Arch) String() string { } } -// FloatingPointData is a generic type, and will always be passed as a pointer. -// We rely on the individual arch implementations to meet all the necessary -// requirements. For example, on x86 the region must be 16-byte aligned and 512 -// bytes in size. -type FloatingPointData []byte - // Context provides architecture-dependent information for a specific thread. // // NOTE(b/34169503): Currently we use uintptr here to refer to a generic native @@ -187,7 +182,7 @@ type Context interface { ClearSingleStep() // FloatingPointData will be passed to underlying save routines. - FloatingPointData() FloatingPointData + FloatingPointData() *fpu.State // NewMmapLayout returns a layout for a new MM, where MinAddr for the // returned layout must be no lower than min, and MaxAddr for the returned @@ -221,16 +216,6 @@ type Context interface { // number of bytes read. PtraceSetRegs(src io.Reader) (int, error) - // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the - // floating-point registers represented by this Context to addr in dst and - // returning the number of bytes written. - PtraceGetFPRegs(dst io.Writer) (int, error) - - // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading - // floating-point registers from src into this Context and returning the - // number of bytes read. 
- PtraceSetFPRegs(src io.Reader) (int, error) - // PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the // register set given by architecture-defined value regset from this // Context to dst and returning the number of bytes written, which must be @@ -365,18 +350,3 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } - -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 6b81e9708..08789f517 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/syserror" ) @@ -40,65 +41,11 @@ type Registers struct { const ( // SyscallWidth is the width of insturctions. SyscallWidth = 4 - - // fpsimdMagic is the magic number which is used in fpsimd_context. - fpsimdMagic = 0x46508001 - - // fpsimdContextSize is the size of fpsimd_context. - fpsimdContextSize = 0x210 ) // ARMTrapFlag is the mask for the trap flag. const ARMTrapFlag = uint64(1) << 21 -// aarch64FPState is aarch64 floating point state. -type aarch64FPState []byte - -// initAarch64FPState sets up initial state. -// -// Related code in Linux kernel: fpsimd_flush_thread(). -// FPCR = FPCR_RM_RN (0x0 << 22). 
-// -// Currently, aarch64FPState is only a space of 0x210 length for fpstate. -// The fp head is useless in sentry/ptrace/kvm. -// -func initAarch64FPState(data aarch64FPState) { -} - -func newAarch64FPStateSlice() []byte { - return alignedBytes(4096, 16)[:fpsimdContextSize] -} - -// newAarch64FPState returns an initialized floating point state. -// -// The returned state is large enough to store all floating point state -// supported by host, even if the app won't use much of it due to a restricted -// FeatureSet. -func newAarch64FPState() aarch64FPState { - f := aarch64FPState(newAarch64FPStateSlice()) - initAarch64FPState(f) - return f -} - -// fork creates and returns an identical copy of the aarch64 floating point state. -func (f aarch64FPState) fork() aarch64FPState { - n := aarch64FPState(newAarch64FPStateSlice()) - copy(n, f) - return n -} - -// FloatingPointData returns the raw data pointer. -func (f aarch64FPState) FloatingPointData() FloatingPointData { - return ([]byte)(f) -} - -// NewFloatingPointData returns a new floating point data blob. -// -// This is primarily for use in tests. -func NewFloatingPointData() FloatingPointData { - return ([]byte)(newAarch64FPState()) -} - // State contains the common architecture bits for aarch64 (the build tag of this // file ensures it's only built on aarch64). // @@ -108,7 +55,7 @@ type State struct { Regs Registers // Our floating point state. - aarch64FPState `state:"wait"` + fpState fpu.State `state:"wait"` // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -162,10 +109,10 @@ func (s State) Proto() *rpb.Registers { // Fork creates and returns an identical copy of the state. 
func (s *State) Fork() State { return State{ - Regs: s.Regs, - aarch64FPState: s.aarch64FPState.fork(), - FeatureSet: s.FeatureSet, - OrigR0: s.OrigR0, + Regs: s.Regs, + fpState: s.fpState.Fork(), + FeatureSet: s.FeatureSet, + OrigR0: s.OrigR0, } } @@ -318,10 +265,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context { case ARM64: return &context64{ State{ - aarch64FPState: newAarch64FPState(), - FeatureSet: fs, + fpState: fpu.NewState(), + FeatureSet: fs, }, - []aarch64FPState(nil), + []fpu.State(nil), } } panic(fmt.Sprintf("unknown architecture %v", arch)) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 15d8ddb40..2571be60f 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -105,7 +106,7 @@ const ( // +stateify savable type context64 struct { State - sigFPState []x86FPState // fpstate to be restored on sigreturn. + sigFPState []fpu.State // fpstate to be restored on sigreturn. } // Arch implements Context.Arch. @@ -113,14 +114,18 @@ func (c *context64) Arch() Arch { return AMD64 } -func (c *context64) copySigFPState() []x86FPState { - var sigfps []x86FPState +func (c *context64) copySigFPState() []fpu.State { + var sigfps []fpu.State for _, s := range c.sigFPState { - sigfps = append(sigfps, s.fork()) + sigfps = append(sigfps, s.Fork()) } return sigfps } +func (c *context64) FloatingPointData() *fpu.State { + return &c.State.fpState +} + // Fork returns an exact copy of this context. 
func (c *context64) Fork() Context { return &context64{ diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 0c61a3ff7..14ad9483b 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -79,7 +80,7 @@ const ( // +stateify savable type context64 struct { State - sigFPState []aarch64FPState // fpstate to be restored on sigreturn. + sigFPState []fpu.State // fpstate to be restored on sigreturn. } // Arch implements Context.Arch. @@ -87,10 +88,10 @@ func (c *context64) Arch() Arch { return ARM64 } -func (c *context64) copySigFPState() []aarch64FPState { - var sigfps []aarch64FPState +func (c *context64) copySigFPState() []fpu.State { + var sigfps []fpu.State for _, s := range c.sigFPState { - sigfps = append(sigfps, s.fork()) + sigfps = append(sigfps, s.Fork()) } return sigfps } @@ -286,3 +287,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64. return nil } + +func (c *context64) FloatingPointData() *fpu.State { + return &c.State.fpState +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 840e53d33..b2b94c304 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -16,59 +16,7 @@ package arch -import ( - "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/usermem" -) - -// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 -// and SSE state, so this is the equivalent XSTATE_BV value. -const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE - // afterLoadFPState is invoked by afterLoad. func (s *State) afterLoadFPState() { - old := s.x86FPState - - // Recreate the slice. 
This is done to ensure that it is aligned - // appropriately in memory, and large enough to accommodate any new - // state that may be saved by the new CPU. Even if extraneous new state - // is saved, the state we care about is guaranteed to be a subset of - // new state. Later optimizations can use less space when using a - // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has - // more info. - s.x86FPState = newX86FPState() - - // x86FPState always contains all the FP state supported by the host. - // We may have come from a newer machine that supports additional state - // which we cannot restore. - // - // The x86 FP state areas are backwards compatible, so we can simply - // truncate the additional floating point state. - // - // Applications should not depend on the truncated state because it - // should relate only to features that were not exposed in the app - // FeatureSet. However, because we do not *prevent* them from using - // this state, we must verify here that there is no in-use state - // (according to XSTATE_BV) which we do not support. - if len(s.x86FPState) < len(old) { - // What do we support? - supportedBV := fxsaveBV - if fs := cpuid.HostFeatureSet(); fs.UseXsave() { - supportedBV = fs.ValidXCR0Mask() - } - - // What was in use? - savedBV := fxsaveBV - if len(old) >= xstateBVOffset+8 { - savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:]) - } - - // Supported features must be a superset of saved features. - if savedBV&^supportedBV != 0 { - panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV}) - } - } - - // Copy to the new, aligned location. 
- copy(s.x86FPState, old) + s.fpState.AfterLoad() } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 91edf0703..e8e52d3a8 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -24,10 +24,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Registers represents the CPU registers for this architecture. @@ -111,57 +110,6 @@ var ( X86TrapFlag uint64 = (1 << 8) ) -// x86FPState is x86 floating point state. -type x86FPState []byte - -// initX86FPState (defined in asm files) sets up initial state. -func initX86FPState(data *byte, useXsave bool) - -func newX86FPStateSlice() []byte { - size, align := cpuid.HostFeatureSet().ExtendedStateSize() - capacity := size - // Always use at least 4096 bytes. - // - // For the KVM platform, this state is a fixed 4096 bytes, so make sure - // that the underlying array is at _least_ that size otherwise we will - // corrupt random memory. This is not a pleasant thing to debug. - if capacity < 4096 { - capacity = 4096 - } - return alignedBytes(capacity, align)[:size] -} - -// newX86FPState returns an initialized floating point state. -// -// The returned state is large enough to store all floating point state -// supported by host, even if the app won't use much of it due to a restricted -// FeatureSet. Since they may still be able to see state not advertised by -// CPUID we must ensure it does not contain any sentry state. -func newX86FPState() x86FPState { - f := x86FPState(newX86FPStateSlice()) - initX86FPState(&f.FloatingPointData()[0], cpuid.HostFeatureSet().UseXsave()) - return f -} - -// fork creates and returns an identical copy of the x86 floating point state. 
-func (f x86FPState) fork() x86FPState { - n := x86FPState(newX86FPStateSlice()) - copy(n, f) - return n -} - -// FloatingPointData returns the raw data pointer. -func (f x86FPState) FloatingPointData() FloatingPointData { - return []byte(f) -} - -// NewFloatingPointData returns a new floating point data blob. -// -// This is primarily for use in tests. -func NewFloatingPointData() FloatingPointData { - return (FloatingPointData)(newX86FPState()) -} - // Proto returns a protobuf representation of the system registers in State. func (s State) Proto() *rpb.Registers { regs := &rpb.AMD64Registers{ @@ -200,7 +148,7 @@ func (s State) Proto() *rpb.Registers { func (s *State) Fork() State { return State{ Regs: s.Regs, - x86FPState: s.x86FPState.fork(), + fpState: s.fpState.Fork(), FeatureSet: s.FeatureSet, } } @@ -393,149 +341,6 @@ func isValidSegmentBase(reg uint64) bool { return reg < uint64(maxAddr64) } -// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type -// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently, -// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area. -const ptraceFPRegsSize = 512 - -// PtraceGetFPRegs implements Context.PtraceGetFPRegs. -func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) { - return dst.Write(s.x86FPState[:ptraceFPRegsSize]) -} - -// PtraceSetFPRegs implements Context.PtraceSetFPRegs. -func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) { - var f [ptraceFPRegsSize]byte - n, err := io.ReadFull(src, f[:]) - if err != nil { - return 0, err - } - // Force reserved bits in MXCSR to 0. This is consistent with Linux. - sanitizeMXCSR(x86FPState(f[:])) - // N.B. this only copies the beginning of the FP state, which - // corresponds to the FXSAVE area. - copy(s.x86FPState, f[:]) - return n, nil -} - -const ( - // mxcsrOffset is the offset in bytes of the MXCSR field from the start of - // the FXSAVE area. (Intel SDM Vol. 
1, Table 10-2 "Format of an FXSAVE - // Area") - mxcsrOffset = 24 - - // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the - // start of the FXSAVE area. - mxcsrMaskOffset = 28 -) - -var ( - mxcsrMask uint32 - initMXCSRMask sync.Once -) - -// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR -// generates a general-protection fault (#GP) in response to an attempt to set -// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section -// 10.5.1.2 "SSE State") -func sanitizeMXCSR(f x86FPState) { - mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:]) - initMXCSRMask.Do(func() { - temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16)) - initX86FPState(&temp.FloatingPointData()[0], false /* useXsave */) - mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) - if mxcsrMask == 0 { - // "If the value of the MXCSR_MASK field is 00000000H, then the - // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM - // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR - // Register" - mxcsrMask = 0xffbf - } - }) - mxcsr &= mxcsrMask - usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) -} - -const ( - // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal - // to the size of the XSAVE legacy area (512 bytes) plus the size of the - // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's - // X86_XSTATE_SSE_SIZE. - minXstateBytes = 512 + 64 - - // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD - // field in Linux's struct user_xstateregs, which is the type manipulated - // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently, - // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET. - userXstateXCR0Offset = 464 - - // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86 - // XSAVE area. 
- xstateBVOffset = 512 - - // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the - // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is - // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE - // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header". - // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP - // exceptions resulting from invalid values; we aren't. Linux also never - // uses the compacted format when doing XSAVE and doesn't even define the - // compaction extensions to XSAVE as a CPU feature, so for simplicity we - // assume no one is using them. - xsaveHeaderZeroedOffset = 512 + 8 - xsaveHeaderZeroedBytes = 64 - 8 -) - -func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) { - // N.B. s.x86FPState may contain more state than the application - // expects. We only copy the subset that would be in their XSAVE area. - ess, _ := s.FeatureSet.ExtendedStateSize() - f := make([]byte, ess) - copy(f, s.x86FPState) - // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are - // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE - // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE - // mask. GDB relies on this: see - // gdb/x86-linux-nat.c:x86_linux_read_description(). - usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask()) - if len(f) > maxlen { - f = f[:maxlen] - } - return dst.Write(f) -} - -func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) { - // Allow users to pass an xstate register set smaller than ours (they can - // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes. - // Also allow users to pass a register set larger than ours; anything after - // their ExtendedStateSize will be ignored. 
(I think Linux technically - // permits setting a register set smaller than minXstateBytes, but it has - // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().) - if maxlen < minXstateBytes { - return 0, unix.EFAULT - } - ess, _ := s.FeatureSet.ExtendedStateSize() - if maxlen > int(ess) { - maxlen = int(ess) - } - f := make([]byte, maxlen) - if _, err := io.ReadFull(src, f); err != nil { - return 0, err - } - // Force reserved bits in MXCSR to 0. This is consistent with Linux. - sanitizeMXCSR(x86FPState(f)) - // Users can't enable *more* XCR0 bits than what we, and the CPU, support. - xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:]) - xstateBV &= s.FeatureSet.ValidXCR0Mask() - usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) - // Force XCOMP_BV and reserved bytes in the XSAVE header to 0. - reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes] - for i := range reserved { - reserved[i] = 0 - } - return copy(s.x86FPState, f), nil -} - // Register sets defined in include/uapi/linux/elf.h. 
const ( _NT_PRSTATUS = 1 @@ -552,12 +357,9 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, } return s.PtraceGetRegs(dst) case _NT_PRFPREG: - if maxlen < ptraceFPRegsSize { - return 0, syserror.EFAULT - } - return s.PtraceGetFPRegs(dst) + return s.fpState.PtraceGetFPRegs(dst, maxlen) case _NT_X86_XSTATE: - return s.ptraceGetXstateRegs(dst, maxlen) + return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet) default: return 0, syserror.EINVAL } @@ -572,12 +374,9 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, } return s.PtraceSetRegs(src) case _NT_PRFPREG: - if maxlen < ptraceFPRegsSize { - return 0, syserror.EFAULT - } - return s.PtraceSetFPRegs(src) + return s.fpState.PtraceSetFPRegs(src, maxlen) case _NT_X86_XSTATE: - return s.ptraceSetXstateRegs(src, maxlen) + return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet) default: return 0, syserror.EINVAL } @@ -609,10 +408,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context { case AMD64: return &context64{ State{ - x86FPState: newX86FPState(), + fpState: fpu.NewState(), FeatureSet: fs, }, - []x86FPState(nil), + []fpu.State(nil), } } panic(fmt.Sprintf("unknown architecture %v", arch)) diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go index 0c73fcbfb..5d7b99bd9 100644 --- a/pkg/sentry/arch/arch_x86_impl.go +++ b/pkg/sentry/arch/arch_x86_impl.go @@ -18,6 +18,7 @@ package arch import ( "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" ) // State contains the common architecture bits for X86 (the build tag of this @@ -29,7 +30,7 @@ type State struct { Regs Registers // Our floating point state. - x86FPState `state:"wait"` + fpState fpu.State `state:"wait"` // FeatureSet is a pointer to the currently active feature set. 
FeatureSet *cpuid.FeatureSet diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD new file mode 100644 index 000000000..0a5395267 --- /dev/null +++ b/pkg/sentry/arch/fpu/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "fpu", + srcs = [ + "fpu.go", + "fpu_amd64.go", + "fpu_amd64.s", + "fpu_arm64.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/cpuid", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go new file mode 100644 index 000000000..867d309a3 --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu.go @@ -0,0 +1,54 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fpu provides basic floating point helpers. +package fpu + +import ( + "fmt" + "reflect" +) + +// State represents floating point state. +// +// This is a simple byte slice, but may have architecture-specific methods +// attached to it. +type State []byte + +// ErrLoadingState indicates a failed restore due to unusable floating point +// state. +type ErrLoadingState struct { + // supported is the supported floating point state. + supportedFeatures uint64 + + // saved is the saved floating point state. + savedFeatures uint64 +} + +// Error returns a sensible description of the restore error. 
+func (e ErrLoadingState) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures) +} + +// alignedBytes returns a slice of size bytes, aligned in memory to the given +// alignment. This is used because we require certain structures to be aligned +// in a specific way (for example, the X86 floating point data). +func alignedBytes(size, alignment uint) []byte { + data := make([]byte, size+alignment-1) + offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment)) + if offset == 0 { + return data[:size:size] + } + return data[alignment-offset:][:size:size] +} diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go new file mode 100644 index 000000000..3a62f51be --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_amd64.go @@ -0,0 +1,280 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 i386 + +package fpu + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// initX86FPState (defined in asm files) sets up initial state. +func initX86FPState(data *byte, useXsave bool) + +func newX86FPStateSlice() State { + size, align := cpuid.HostFeatureSet().ExtendedStateSize() + capacity := size + // Always use at least 4096 bytes. 
+ // + // For the KVM platform, this state is a fixed 4096 bytes, so make sure + // that the underlying array is at _least_ that size otherwise we will + // corrupt random memory. This is not a pleasant thing to debug. + if capacity < 4096 { + capacity = 4096 + } + return alignedBytes(capacity, align)[:size] +} + +// NewState returns an initialized floating point state. +// +// The returned state is large enough to store all floating point state +// supported by host, even if the app won't use much of it due to a restricted +// FeatureSet. Since they may still be able to see state not advertised by +// CPUID we must ensure it does not contain any sentry state. +func NewState() State { + f := newX86FPStateSlice() + initX86FPState(&f[0], cpuid.HostFeatureSet().UseXsave()) + return f +} + +// Fork creates and returns an identical copy of the x86 floating point state. +func (s *State) Fork() State { + n := newX86FPStateSlice() + copy(n, *s) + return n +} + +// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type +// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently, +// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area. +const ptraceFPRegsSize = 512 + +// PtraceGetFPRegs implements Context.PtraceGetFPRegs. +func (s *State) PtraceGetFPRegs(dst io.Writer, maxlen int) (int, error) { + if maxlen < ptraceFPRegsSize { + return 0, syserror.EFAULT + } + + return dst.Write((*s)[:ptraceFPRegsSize]) +} + +// PtraceSetFPRegs implements Context.PtraceSetFPRegs. +func (s *State) PtraceSetFPRegs(src io.Reader, maxlen int) (int, error) { + if maxlen < ptraceFPRegsSize { + return 0, syserror.EFAULT + } + + var f [ptraceFPRegsSize]byte + n, err := io.ReadFull(src, f[:]) + if err != nil { + return 0, err + } + // Force reserved bits in MXCSR to 0. This is consistent with Linux. + sanitizeMXCSR(State(f[:])) + // N.B. this only copies the beginning of the FP state, which + // corresponds to the FXSAVE area. 
+ copy(*s, f[:]) + return n, nil +} + +const ( + // mxcsrOffset is the offset in bytes of the MXCSR field from the start of + // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE + // Area") + mxcsrOffset = 24 + + // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the + // start of the FXSAVE area. + mxcsrMaskOffset = 28 +) + +var ( + mxcsrMask uint32 + initMXCSRMask sync.Once +) + +const ( + // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal + // to the size of the XSAVE legacy area (512 bytes) plus the size of the + // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's + // X86_XSTATE_SSE_SIZE. + minXstateBytes = 512 + 64 + + // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD + // field in Linux's struct user_xstateregs, which is the type manipulated + // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently, + // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET. + userXstateXCR0Offset = 464 + + // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86 + // XSAVE area. + xstateBVOffset = 512 + + // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the + // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is + // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE + // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header". + // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP + // exceptions resulting from invalid values; we aren't. Linux also never + // uses the compacted format when doing XSAVE and doesn't even define the + // compaction extensions to XSAVE as a CPU feature, so for simplicity we + // assume no one is using them. + xsaveHeaderZeroedOffset = 512 + 8 + xsaveHeaderZeroedBytes = 64 - 8 +) + +// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. 
("FXRSTOR +// generates a general-protection fault (#GP) in response to an attempt to set +// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section +// 10.5.1.2 "SSE State") +func sanitizeMXCSR(f State) { + mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:]) + initMXCSRMask.Do(func() { + temp := State(alignedBytes(uint(ptraceFPRegsSize), 16)) + initX86FPState(&temp[0], false /* useXsave */) + mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) + if mxcsrMask == 0 { + // "If the value of the MXCSR_MASK field is 00000000H, then the + // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM + // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR + // Register" + mxcsrMask = 0xffbf + } + }) + mxcsr &= mxcsrMask + usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) +} + +// PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by +// writing the floating point registers from this state to dst and returning the +// number of bytes written, which must be less than or equal to maxlen. +func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid.FeatureSet) (int, error) { + // N.B. s.x86FPState may contain more state than the application + // expects. We only copy the subset that would be in their XSAVE area. + ess, _ := featureSet.ExtendedStateSize() + f := make([]byte, ess) + copy(f, *s) + // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are + // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE + // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE + // mask. GDB relies on this: see + // gdb/x86-linux-nat.c:x86_linux_read_description(). 
+ usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask()) + if len(f) > maxlen { + f = f[:maxlen] + } + return dst.Write(f) +} + +// PtraceSetXstateRegs implements ptrace(PTRACE_SETREGS, NT_X86_XSTATE) by +// reading floating point registers from src and returning the number of bytes +// read, which must be less than or equal to maxlen. +func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid.FeatureSet) (int, error) { + // Allow users to pass an xstate register set smaller than ours (they can + // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes. + // Also allow users to pass a register set larger than ours; anything after + // their ExtendedStateSize will be ignored. (I think Linux technically + // permits setting a register set smaller than minXstateBytes, but it has + // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().) + if maxlen < minXstateBytes { + return 0, unix.EFAULT + } + ess, _ := featureSet.ExtendedStateSize() + if maxlen > int(ess) { + maxlen = int(ess) + } + f := make([]byte, maxlen) + if _, err := io.ReadFull(src, f); err != nil { + return 0, err + } + // Force reserved bits in MXCSR to 0. This is consistent with Linux. + sanitizeMXCSR(State(f)) + // Users can't enable *more* XCR0 bits than what we, and the CPU, support. + xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:]) + xstateBV &= featureSet.ValidXCR0Mask() + usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) + // Force XCOMP_BV and reserved bytes in the XSAVE header to 0. + reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes] + for i := range reserved { + reserved[i] = 0 + } + return copy(*s, f), nil +} + +// BytePointer returns a pointer to the first byte of the state. 
+// +//go:nosplit +func (s *State) BytePointer() *byte { + return &(*s)[0] +} + +// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 +// and SSE state, so this is the equivalent XSTATE_BV value. +const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE + +// AfterLoad converts the loaded state to the format that compatible with the +// current processor. +func (s *State) AfterLoad() { + old := *s + + // Recreate the slice. This is done to ensure that it is aligned + // appropriately in memory, and large enough to accommodate any new + // state that may be saved by the new CPU. Even if extraneous new state + // is saved, the state we care about is guaranteed to be a subset of + // new state. Later optimizations can use less space when using a + // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has + // more info. + *s = NewState() + + // x86FPState always contains all the FP state supported by the host. + // We may have come from a newer machine that supports additional state + // which we cannot restore. + // + // The x86 FP state areas are backwards compatible, so we can simply + // truncate the additional floating point state. + // + // Applications should not depend on the truncated state because it + // should relate only to features that were not exposed in the app + // FeatureSet. However, because we do not *prevent* them from using + // this state, we must verify here that there is no in-use state + // (according to XSTATE_BV) which we do not support. + if len(*s) < len(old) { + // What do we support? + supportedBV := fxsaveBV + if fs := cpuid.HostFeatureSet(); fs.UseXsave() { + supportedBV = fs.ValidXCR0Mask() + } + + // What was in use? + savedBV := fxsaveBV + if len(old) >= xstateBVOffset+8 { + savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:]) + } + + // Supported features must be a superset of saved features. 
+ if savedBV&^supportedBV != 0 { + panic(ErrLoadingState{supportedFeatures: supportedBV, savedFeatures: savedBV}) + } + } + + // Copy to the new, aligned location. + copy(*s, old) +} diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/fpu/fpu_amd64.s index 6c10336e7..6c10336e7 100644 --- a/pkg/sentry/arch/arch_amd64.s +++ b/pkg/sentry/arch/fpu/fpu_amd64.s diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go new file mode 100644 index 000000000..d2f62631d --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_arm64.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package fpu + +const ( + // fpsimdMagic is the magic number which is used in fpsimd_context. + fpsimdMagic = 0x46508001 + + // fpsimdContextSize is the size of fpsimd_context. + fpsimdContextSize = 0x210 +) + +// initAarch64FPState sets up initial state. +// +// Related code in Linux kernel: fpsimd_flush_thread(). +// FPCR = FPCR_RM_RN (0x0 << 22). +// +// Currently, aarch64FPState is only a space of 0x210 length for fpstate. +// The fp head is useless in sentry/ptrace/kvm. +// +func initAarch64FPState(data *State) { +} + +func newAarch64FPStateSlice() []byte { + return alignedBytes(4096, 16)[:fpsimdContextSize] +} + +// NewState returns an initialized floating point state. 
+// +// The returned state is large enough to store all floating point state +// supported by host, even if the app won't use much of it due to a restricted +// FeatureSet. +func NewState() State { + f := State(newAarch64FPStateSlice()) + initAarch64FPState(&f) + return f +} + +// Fork creates and returns an identical copy of the aarch64 floating point state. +func (s *State) Fork() State { + n := State(newAarch64FPStateSlice()) + copy(n, *s) + return n +} + +// BytePointer returns a pointer to the first byte of the state. +func (s *State) BytePointer() *byte { + return &(*s)[0] +} diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index e6557cab6..ee3743483 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -98,7 +99,7 @@ func (c *context64) NewSignalStack() NativeSignalStack { const _FP_XSTATE_MAGIC2_SIZE = 4 func (c *context64) fpuFrameSize() (size int, useXsave bool) { - size = len(c.x86FPState) + size = len(c.fpState) if size > 512 { // Make room for the magic cookie at the end of the xsave frame. size += _FP_XSTATE_MAGIC2_SIZE @@ -226,10 +227,10 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Ss = userDS // Save the thread's floating point state. - c.sigFPState = append(c.sigFPState, c.x86FPState) + c.sigFPState = append(c.sigFPState, c.fpState) // Signal handler gets a clean floating point state. - c.x86FPState = newX86FPState() + c.fpState = fpu.NewState() return nil } @@ -273,7 +274,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt // Restore floating point state. 
l := len(c.sigFPState) if l > 0 { - c.x86FPState = c.sigFPState[l-1] + c.fpState = c.sigFPState[l-1] // NOTE(cl/133042258): State save requires that any slice // elements from '[len:cap]' to be zero value. c.sigFPState[l-1] = nil diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 4491008c2..53281dcba 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -139,9 +140,9 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Regs[30] = uint64(act.Restorer) // Save the thread's floating point state. - c.sigFPState = append(c.sigFPState, c.aarch64FPState) + c.sigFPState = append(c.sigFPState, c.fpState) // Signal handler gets a clean floating point state. - c.aarch64FPState = newAarch64FPState() + c.fpState = fpu.NewState() return nil } @@ -166,7 +167,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt // Restore floating point state. l := len(c.sigFPState) if l > 0 { - c.aarch64FPState = c.sigFPState[l-1] + c.fpState = c.sigFPState[l-1] // NOTE(cl/133042258): State save requires that any slice // elements from '[len:cap]' to be zero value. c.sigFPState[l-1] = nil diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 82eda3e43..0ed7aafa5 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -380,16 +380,17 @@ func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length in return nil } -// WriteOut implements fs.InodeOperations.WriteOut. -func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { +// WriteDirtyPagesAndAttrs will write the dirty pages and attributes to the +// gofer without calling Fsync on the remote file. 
+func (c *CachingInodeOperations) WriteDirtyPagesAndAttrs(ctx context.Context, inode *fs.Inode) error { c.attrMu.Lock() + defer c.attrMu.Unlock() + c.dataMu.Lock() + defer c.dataMu.Unlock() // Write dirty pages back. - c.dataMu.Lock() err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt) - c.dataMu.Unlock() if err != nil { - c.attrMu.Unlock() return err } @@ -399,12 +400,18 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) // Write out cached attributes. if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil { - c.attrMu.Unlock() return err } c.dirtyAttr = fs.AttrMask{} - c.attrMu.Unlock() + return nil +} + +// WriteOut implements fs.InodeOperations.WriteOut. +func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { + if err := c.WriteDirtyPagesAndAttrs(ctx, inode); err != nil { + return err + } // Fsync the remote file. return c.backingFile.Sync(ctx) diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 06d450ba6..8f5a87120 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -204,20 +204,8 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr return entries, nil } -// maybeSync will call FSync on the file if either the cache policy or file -// flags require it. +// maybeSync will call FSync on the file if the file flags require it. func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error { - if n == 0 { - // Nothing to sync. - return nil - } - - if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { - // Call WriteOut directly, as some "writethrough" filesystems - // do not support sync. 
- return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) - } - flags := file.Flags() var syncType fs.SyncType switch { @@ -254,6 +242,19 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) } + if n == 0 { + // Nothing written. We are done. + return 0, err + } + + // Write the dirty pages and attributes if cache policy tells us to. + if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { + if werr := f.inodeOperations.cachingInodeOps.WriteDirtyPagesAndAttrs(ctx, file.Dirent.Inode); werr != nil { + // Report no bytes written since the write faild. + return 0, werr + } + } + // We may need to sync the written bytes. if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil { // Sync failed. Report 0 bytes written, since none of them are diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index c34451269..43c3c5a2d 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -783,7 +783,15 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { - if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil { + // If the parent is a setgid directory, use the parent's GID + // rather than the caller's and enable setgid. 
+ kgid := creds.EffectiveKGID + mode := opts.Mode + if atomic.LoadUint32(&parent.mode)&linux.S_ISGID != 0 { + kgid = auth.KGID(atomic.LoadUint32(&parent.gid)) + mode |= linux.S_ISGID + } + if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } @@ -1145,7 +1153,15 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving name := rp.Component() // We only want the access mode for creating the file. createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask - fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) + + // If the parent is a setgid directory, use the parent's GID rather + // than the caller's. + kgid := creds.EffectiveKGID + if atomic.LoadUint32(&d.mode)&linux.S_ISGID != 0 { + kgid = auth.KGID(atomic.LoadUint32(&d.gid)) + } + + fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) if err != nil { dirfile.close(ctx) return nil, err diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 71569dc65..692da02c1 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1102,10 +1102,26 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs d.metadataMu.Lock() defer d.metadataMu.Unlock() + + // As with Linux, if the UID, GID, or file size is changing, we have to + // clear permission bits. Note that when set, clearSGID causes + // permissions to be updated, but does not modify stat.Mask, as + // modification would cause an extra inotify flag to be set. 
+ clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) || + stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) || + stat.Mask&linux.STATX_SIZE != 0 + if clearSGID { + if stat.Mask&linux.STATX_MODE != 0 { + stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode))) + } else { + stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode))) + } + } + if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0, + Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID, UID: stat.Mask&linux.STATX_UID != 0, GID: stat.Mask&linux.STATX_GID != 0, Size: stat.Mask&linux.STATX_SIZE != 0, @@ -1140,7 +1156,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 { + if stat.Mask&linux.STATX_MODE != 0 || clearSGID { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } if stat.Mask&linux.STATX_UID != 0 { diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 283b220bb..4f1ad0c88 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -266,6 +266,20 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off return 0, offset, err } } + + // As with Linux, writing clears the setuid and setgid bits. + if n > 0 { + oldMode := atomic.LoadUint32(&d.mode) + // If setuid or setgid were set, update d.mode and propagate + // changes to the host. 
+ if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode { + atomic.StoreUint32(&d.mode, newMode) + if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { + return 0, offset, err + } + } + } + return n, offset + n, nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 84e37f793..46c500427 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -689,13 +689,9 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } return err } - creds := rp.Credentials() + if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, rp.Credentials()), }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)) @@ -750,11 +746,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } creds := rp.Credentials() if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, creds), }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)) @@ -963,14 +955,11 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } + // Change the file's owner to the caller. 
We can't use upperFD.SetStat() // because it will pick up creds from ctx. if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, creds), }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)) diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 58680bc80..454c20d4f 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -749,6 +749,27 @@ func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { ) } +// newChildOwnerStat returns a Statx for configuring the UID, GID, and mode of +// children. +func (d *dentry) newChildOwnerStat(mode linux.FileMode, creds *auth.Credentials) linux.Statx { + stat := linux.Statx{ + Mask: uint32(linux.STATX_UID | linux.STATX_GID), + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + } + // Set GID and possibly the SGID bit if the parent is an SGID directory. + d.copyMu.RLock() + defer d.copyMu.RUnlock() + if atomic.LoadUint32(&d.mode)&linux.ModeSetGID == linux.ModeSetGID { + stat.GID = atomic.LoadUint32(&d.gid) + if stat.Mode&linux.ModeDirectory == linux.ModeDirectory { + stat.Mode = uint16(mode) | linux.ModeSetGID + stat.Mask |= linux.STATX_MODE + } + } + return stat +} + // fileDescription is embedded by overlay implementations of // vfs.FileDescriptionImpl. 
// diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 25c785fd4..d791c06db 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -205,6 +205,20 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e if err := wrappedFD.SetStat(ctx, opts); err != nil { return err } + + // Changing owners may clear one or both of the setuid and setgid bits, + // so we may have to update opts before setting d.mode. + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{ + Mask: linux.STATX_MODE, + }) + if err != nil { + return err + } + opts.Stat.Mode = stat.Mode + opts.Stat.Mask |= linux.STATX_MODE + } + d.updateAfterSetStatLocked(&opts) if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) @@ -295,7 +309,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } defer wrappedFD.DecRef(ctx) - return wrappedFD.PWrite(ctx, src, offset, opts) + n, err := wrappedFD.PWrite(ctx, src, offset, opts) + if err != nil { + return n, err + } + return fd.updateSetUserGroupIDs(ctx, wrappedFD, n) } // Write implements vfs.FileDescriptionImpl.Write. @@ -307,7 +325,28 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts if err != nil { return 0, err } - return wrappedFD.Write(ctx, src, opts) + n, err := wrappedFD.Write(ctx, src, opts) + if err != nil { + return n, err + } + return fd.updateSetUserGroupIDs(ctx, wrappedFD, n) +} + +func (fd *regularFileFD) updateSetUserGroupIDs(ctx context.Context, wrappedFD *vfs.FileDescription, written int64) (int64, error) { + // Writing can clear the setuid and/or setgid bits. We only have to + // check this if something was written and one of those bits was set. 
+ dentry := fd.dentry() + if written == 0 || atomic.LoadUint32(&dentry.mode)&(linux.S_ISUID|linux.S_ISGID) == 0 { + return written, nil + } + stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_MODE}) + if err != nil { + return written, err + } + dentry.copyMu.Lock() + defer dentry.copyMu.Unlock() + atomic.StoreUint32(&dentry.mode, uint32(stat.Mode)) + return written, nil } // Seek implements vfs.FileDescriptionImpl.Seek. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index 609ad3941..7aea3dcd8 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -51,14 +51,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro return err case linux.PTRACE_GETFPREGS: - _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{ + s := target.Arch().FloatingPointData() + _, err := target.Arch().FloatingPointData().PtraceGetFPRegs(&usermem.IOReadWriter{ Ctx: t, IO: t.MemoryManager(), Addr: data, Opts: usermem.IOOpts{ AddressSpaceActive: true, }, - }) + }, len(*s)) return err case linux.PTRACE_SETREGS: @@ -73,14 +74,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro return err case linux.PTRACE_SETFPREGS: - _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{ + s := target.Arch().FloatingPointData() + _, err := s.PtraceSetFPRegs(&usermem.IOReadWriter{ Ctx: t, IO: t.MemoryManager(), Addr: data, Opts: usermem.IOOpts{ AddressSpaceActive: true, }, - }) + }, len(*s)) return err default: diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 4f9e781af..03a76eb9b 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", @@ -78,6 +79,7 @@ go_test( "//pkg/ring0", "//pkg/ring0/pagetables", 
"//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/time", diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 308696efe..d761bbdee 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -73,7 +73,7 @@ func (c *vCPU) KernelSyscall() { // We only trigger a bluepill entry in the bluepill function, and can // therefore be guaranteed that there is no floating point state to be // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no. + ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. ring0.Halt() ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment. } @@ -92,7 +92,7 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no. + ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. ring0.Halt() ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment. } @@ -124,5 +124,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // Set the context pointer to the saved floating point state. This is // where the guest data has been serialized, the kernel will restore // from this new pointer value. 
- context.Fpstate = uint64(uintptrValue(&c.floatingPointState[0])) + context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer())) } diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index c317f1e99..578852c3f 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -92,7 +92,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) context.Fpsimd64.Fpsr = fpsimd.Fpsr context.Fpsimd64.Fpcr = fpsimd.Fpcr context.Fpsimd64.Vregs = fpsimd.Vregs @@ -112,12 +112,12 @@ func (c *vCPU) KernelSyscall() { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(&c.floatingPointState[0]) + ring0.SaveVRegs(c.floatingPointState.BytePointer()) } ring0.Halt() @@ -136,12 +136,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(&c.floatingPointState[0]) + ring0.SaveVRegs(c.floatingPointState.BytePointer()) } ring0.Halt() diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index 76fc594a0..e44e995a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -33,7 +33,7 @@ func TestSegments(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: 
dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6243b9a04..5bce16dde 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -25,13 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) -var dummyFPState = (*byte)(arch.NewFloatingPointData()) +var dummyFPState = fpu.NewState() type testHarness interface { Errorf(format string, args ...interface{}) @@ -159,7 +160,7 @@ func TestApplicationSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -173,7 +174,7 @@ func TestApplicationSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. 
@@ -190,7 +191,7 @@ func TestApplicationFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -205,7 +206,7 @@ func TestApplicationFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. @@ -223,7 +224,7 @@ func TestRegistersSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { continue // Retry. @@ -246,7 +247,7 @@ func TestRegistersFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -272,7 +273,7 @@ func TestBounce(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err != platform.ErrContextInterrupt { t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) @@ -287,7 +288,7 @@ func TestBounce(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err != platform.ErrContextInterrupt { @@ -319,7 +320,7 @@ func TestBounceStress(t *testing.T) { var si arch.SignalInfo if _, err := 
c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err != platform.ErrContextInterrupt { t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) @@ -340,7 +341,7 @@ func TestInvalidate(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { continue // Retry. @@ -355,7 +356,7 @@ func TestInvalidate(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, Flush: true, }, &si); err == platform.ErrContextInterrupt { @@ -379,7 +380,7 @@ func TestEmptyAddressSpace(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. 
@@ -393,7 +394,7 @@ func TestEmptyAddressSpace(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -469,7 +470,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { a++ @@ -506,7 +507,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { a++ diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 916903881..3af96c7e5 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" @@ -70,7 +71,7 @@ type vCPUArchState struct { // floatingPointState is the floating point state buffer used in guest // to host transitions. See usage in bluepill_amd64.go. - floatingPointState arch.FloatingPointData + floatingPointState fpu.State } const ( @@ -151,7 +152,7 @@ func (c *vCPU) initArchState() error { // This will be saved prior to leaving the guest, and we restore from // this always. We cannot use the pointer in the context alone because // we don't know how large the area there is in reality. 
- c.floatingPointState = arch.NewFloatingPointData() + c.floatingPointState = fpu.NewState() // Set the time offset to the host native time. return c.setSystemTime() @@ -307,12 +308,12 @@ func loadByte(ptr *byte) byte { // emulate instructions like xsave and xrstor. // //go:nosplit -func prefaultFloatingPointState(data arch.FloatingPointData) { - size := len(data) +func prefaultFloatingPointState(data *fpu.State) { + size := len(*data) for i := 0; i < size; i += usermem.PageSize { - loadByte(&(data)[i]) + loadByte(&(*data)[i]) } - loadByte(&(data)[size-1]) + loadByte(&(*data)[size-1]) } // SwitchToUser unpacks architectural-details. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 3d715e570..2edc9d1b2 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -32,7 +33,7 @@ type vCPUArchState struct { // floatingPointState is the floating point state buffer used in guest // to host transitions. See usage in bluepill_arm64.go. 
- floatingPointState arch.FloatingPointData + floatingPointState fpu.State } const ( diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 059aa43d0..e7d5f3193 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -150,7 +151,7 @@ func (c *vCPU) initArchState() error { c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) } - c.floatingPointState = arch.NewFloatingPointData() + c.floatingPointState = fpu.NewState() return c.setSystemTime() } diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index fc43cc3c0..47efde6a2 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 6259350ec..01e73b019 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -62,9 +63,9 @@ func (t *thread) setRegs(regs *arch.Registers) error { } // getFPRegs gets the floating-point data via the GETREGSET ptrace unix. 
-func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error { +func (t *thread) getFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error { iovec := unix.Iovec{ - Base: (*byte)(&fpState[0]), + Base: fpState.BytePointer(), Len: fpLen, } _, _, errno := unix.RawSyscall6( @@ -81,9 +82,9 @@ func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsav } // setFPRegs sets the floating-point data via the SETREGSET ptrace unix. -func (t *thread) setFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error { +func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error { iovec := unix.Iovec{ - Base: (*byte)(&fpState[0]), + Base: fpState.BytePointer(), Len: fpLen, } _, _, errno := unix.RawSyscall6( diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 5bd526b73..efec93f73 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -75,17 +75,25 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o // errors, we may consume the error and return only the partial read/write. // // Returns false if error is unknown. -func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, op string) (bool, error) { - switch err { - case nil: +func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr error, op string) (bool, error) { + if errOrig == nil { // Typical successful syscall. return true, nil + } + + // Translate error, if possible, to consolidate errors from other packages + // into a smaller set of errors from syserror package. + translatedErr := errOrig + if errno, ok := syserror.TranslateError(errOrig); ok { + translatedErr = errno + } + switch translatedErr { case io.EOF: // EOF is always consumed. If this is a partial read/write // (result != 0), the application will see that, otherwise // they will see 0. 
return true, nil - case syserror.ErrExceedsFileSizeLimit: + case syserror.EFBIG: t := kernel.TaskFromContext(ctx) if t == nil { panic("I/O error should only occur from a context associated with a Task") @@ -98,7 +106,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, // Simultaneously send a SIGXFSZ per setrlimit(2). t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) return true, syserror.EFBIG - case syserror.ErrInterrupted: + case syserror.EINTR: // The syscall was interrupted. Return nil if it completed // partially, otherwise return the error code that the syscall // needs (to indicate to the kernel what it should do). @@ -110,10 +118,10 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, if !partialResult { // Typical syscall error. - return true, err + return true, errOrig } - switch err { + switch translatedErr { case syserror.EINTR: // Syscall interrupted, but completed a partial // read/write. Like ErrWouldBlock, since we have a @@ -143,7 +151,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil - case syserror.ErrWouldBlock: + case syserror.EWOULDBLOCK: // Syscall would block, but completed a partial read/write. // This case should only be returned by IssueIO for nonblocking // files. Since we have a partial read/write, we consume @@ -151,7 +159,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, return true, nil } - switch err.(type) { + switch errOrig.(type) { case syserror.SyscallRestartErrno: // Identical to the EINTR case. 
return true, nil diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 97de17afe..56b621357 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -130,17 +130,15 @@ func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) { // TranslateError translates errors to errnos, it will return false if // the error was not registered. func TranslateError(from error) (unix.Errno, bool) { - err, ok := errorMap[from] - if ok { - return err, ok + if err, ok := errorMap[from]; ok { + return err, true } // Try to unwrap the error if we couldn't match an error // exactly. This might mean that a package has its own // error type. for _, unwrap := range errorUnwrappers { - err, ok := unwrap(from) - if ok { - return err, ok + if err, ok := unwrap(from); ok { + return err, true } } return 0, false diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fc622b246..fef065b05 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1287,6 +1287,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) } + case header.NDPNonceOption: + gotOpt, ok := opt.(header.NDPNonceOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } default: t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 554242f0c..3d1bccd15 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -26,29 +26,33 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -// NDPOptionIdentifier is an NDP option type identifier. 
-type NDPOptionIdentifier uint8 +// ndpOptionIdentifier is an NDP option type identifier. +type ndpOptionIdentifier uint8 const ( - // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer + // ndpSourceLinkLayerAddressOptionType is the type of the Source Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1 + ndpSourceLinkLayerAddressOptionType ndpOptionIdentifier = 1 - // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer + // ndpTargetLinkLayerAddressOptionType is the type of the Target Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2 + ndpTargetLinkLayerAddressOptionType ndpOptionIdentifier = 2 - // NDPPrefixInformationType is the type of the Prefix Information + // ndpPrefixInformationType is the type of the Prefix Information // option, as per RFC 4861 section 4.6.2. - NDPPrefixInformationType NDPOptionIdentifier = 3 + ndpPrefixInformationType ndpOptionIdentifier = 3 - // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // ndpNonceOptionType is the type of the Nonce option, as per + // RFC 3971 section 5.3.2. + ndpNonceOptionType ndpOptionIdentifier = 14 + + // ndpRecursiveDNSServerOptionType is the type of the Recursive DNS // Server option, as per RFC 8106 section 5.1. - NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 + ndpRecursiveDNSServerOptionType ndpOptionIdentifier = 25 - // NDPDNSSearchListOptionType is the type of the DNS Search List option, + // ndpDNSSearchListOptionType is the type of the DNS Search List option, // as per RFC 8106 section 5.2. - NDPDNSSearchListOptionType = 31 + ndpDNSSearchListOptionType ndpOptionIdentifier = 31 ) const ( @@ -198,7 +202,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { // bytes for the whole option. 
return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) } - kind := NDPOptionIdentifier(temp) + kind := ndpOptionIdentifier(temp) // Get the Length field. length, err := i.opts.ReadByte() @@ -225,13 +229,16 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { } switch kind { - case NDPSourceLinkLayerAddressOptionType: + case ndpSourceLinkLayerAddressOptionType: return NDPSourceLinkLayerAddressOption(body), false, nil - case NDPTargetLinkLayerAddressOptionType: + case ndpTargetLinkLayerAddressOptionType: return NDPTargetLinkLayerAddressOption(body), false, nil - case NDPPrefixInformationType: + case ndpNonceOptionType: + return NDPNonceOption(body), false, nil + + case ndpPrefixInformationType: // Make sure the length of a Prefix Information option // body is ndpPrefixInformationLength, as per RFC 4861 // section 4.6.2. @@ -241,7 +248,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { return NDPPrefixInformation(body), false, nil - case NDPRecursiveDNSServerOptionType: + case ndpRecursiveDNSServerOptionType: opt := NDPRecursiveDNSServer(body) if err := opt.checkAddresses(); err != nil { return nil, true, err @@ -249,7 +256,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { return opt, false, nil - case NDPDNSSearchListOptionType: + case ndpDNSSearchListOptionType: opt := NDPDNSSearchList(body) if err := opt.checkDomainNames(); err != nil { return nil, true, err @@ -316,7 +323,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { continue } - b[0] = byte(o.Type()) + b[0] = byte(o.kind()) // We know this safe because paddedLength would have returned // 0 if o had an invalid length (> 255 * lengthByteUnits). @@ -341,11 +348,11 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { type NDPOption interface { fmt.Stringer - // Type returns the type of the receiver. - Type() NDPOptionIdentifier + // kind returns the type of the receiver. 
+ kind() ndpOptionIdentifier - // Length returns the length of the body of the receiver, in bytes. - Length() int + // length returns the length of the body of the receiver, in bytes. + length() int // serializeInto serializes the receiver into the provided byte // buffer. @@ -365,7 +372,7 @@ type NDPOption interface { // paddedLength returns the length of o, in bytes, with any padding bytes, if // required. func paddedLength(o NDPOption) int { - l := o.Length() + l := o.length() if l == 0 { return 0 @@ -416,6 +423,37 @@ func (b NDPOptionsSerializer) Length() int { return l } +// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPNonceOption []byte + +// kind implements NDPOption. +func (o NDPNonceOption) kind() ndpOptionIdentifier { + return ndpNonceOptionType +} + +// length implements NDPOption. +func (o NDPNonceOption) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPNonceOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPNonceOption) String() string { + return fmt.Sprintf("%T(%x)", o, []byte(o)) +} + +// Nonce returns the nonce value this option holds. +func (o NDPNonceOption) Nonce() []byte { + return []byte(o) +} + // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option // as defined by RFC 4861 section 4.6.1. // @@ -423,22 +461,22 @@ func (b NDPOptionsSerializer) Length() int { // where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPSourceLinkLayerAddressOption tcpip.LinkAddress -// Type implements NDPOption.Type. -func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier { - return NDPSourceLinkLayerAddressOptionType +// kind implements NDPOption. 
+func (o NDPSourceLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpSourceLinkLayerAddressOptionType } -// Length implements NDPOption.Length. -func (o NDPSourceLinkLayerAddressOption) Length() int { +// length implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPSourceLinkLayerAddressOption) String() string { return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) } @@ -463,22 +501,22 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { // where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPTargetLinkLayerAddressOption tcpip.LinkAddress -// Type implements NDPOption.Type. -func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier { - return NDPTargetLinkLayerAddressOptionType +// kind implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpTargetLinkLayerAddressOptionType } -// Length implements NDPOption.Length. -func (o NDPTargetLinkLayerAddressOption) Length() int { +// length implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPTargetLinkLayerAddressOption) String() string { return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) } @@ -503,17 +541,17 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { // ndpPrefixInformationLength bytes. 
type NDPPrefixInformation []byte -// Type implements NDPOption.Type. -func (o NDPPrefixInformation) Type() NDPOptionIdentifier { - return NDPPrefixInformationType +// kind implements NDPOption. +func (o NDPPrefixInformation) kind() ndpOptionIdentifier { + return ndpPrefixInformationType } -// Length implements NDPOption.Length. -func (o NDPPrefixInformation) Length() int { +// length implements NDPOption. +func (o NDPPrefixInformation) length() int { return ndpPrefixInformationLength } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPPrefixInformation) serializeInto(b []byte) int { used := copy(b, o) @@ -529,7 +567,7 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPPrefixInformation) String() string { return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", o, @@ -627,17 +665,17 @@ type NDPRecursiveDNSServer []byte // Type returns the type of an NDP Recursive DNS Server option. // -// Type implements NDPOption.Type. -func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier { - return NDPRecursiveDNSServerOptionType +// kind implements NDPOption. +func (NDPRecursiveDNSServer) kind() ndpOptionIdentifier { + return ndpRecursiveDNSServerOptionType } -// Length implements NDPOption.Length. -func (o NDPRecursiveDNSServer) Length() int { +// length implements NDPOption. +func (o NDPRecursiveDNSServer) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { used := copy(b, o) @@ -649,7 +687,7 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. 
func (o NDPRecursiveDNSServer) String() string { lt := o.Lifetime() addrs, err := o.Addresses() @@ -722,17 +760,17 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { // RFC 8106 section 5.2. type NDPDNSSearchList []byte -// Type implements NDPOption.Type. -func (o NDPDNSSearchList) Type() NDPOptionIdentifier { - return NDPDNSSearchListOptionType +// kind implements NDPOption. +func (o NDPDNSSearchList) kind() ndpOptionIdentifier { + return ndpDNSSearchListOptionType } -// Length implements NDPOption.Length. -func (o NDPDNSSearchList) Length() int { +// length implements NDPOption. +func (o NDPDNSSearchList) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPDNSSearchList) serializeInto(b []byte) int { used := copy(b, o) @@ -744,7 +782,7 @@ func (o NDPDNSSearchList) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPDNSSearchList) String() string { lt := o.Lifetime() domainNames, err := o.DomainNames() diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index dc4591253..d0a1a2492 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,6 +16,7 @@ package header import ( "bytes" + "encoding/binary" "errors" "fmt" "io" @@ -192,90 +193,6 @@ func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a -// NDPSourceLinkLayerAddressOption. 
-func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPSourceLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - sll := next.(NDPSourceLinkLayerAddressOption) - if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. 
- next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - // TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the // Ethernet address from an NDPTargetLinkLayerAddressOption. func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { @@ -311,32 +228,309 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { +func TestOpts(t *testing.T) { + const optionHeaderLen = 2 + + checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpNonceOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType) + } + nonce, ok := opt.(NDPNonceOption) + if !ok { + t.Fatalf("got nonce = %T, want = NDPNonceOption", opt) + } + if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } + } + } + + checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) + } + tll, ok := opt.(NDPTargetLinkLayerAddressOption) + if !ok { + t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt) + } + if got, want := tll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) 
{ + if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) + } + sll, ok := opt.(NDPSourceLinkLayerAddressOption) + if !ok { + t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt) + } + if got, want := sll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + const validLifetimeSeconds = 16909060 + const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + + expectedRDNSSBytes := [...]byte{ + // Type, Length + 25, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 2, 4, 8, + + // Address + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds) + if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + rdnssBytes := expectedRDNSSBytes + rdnssBytes[1] = 1 + rdnssBytes[2] = 2 + + const searchListPaddingBytes = 3 + const domainName = "abc.abcd.e" + expectedSearchListBytes := [...]byte{ + // Type, Length + 31, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 0, 0, 0, + + // Domain names + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds) + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. 
+ searchListBytes := expectedSearchListBytes + searchListBytes[2] = 1 + searchListBytes[3] = 2 + + const prefixLength = 43 + const onLinkFlag = false + const slaacFlag = true + const preferredLifetimeSeconds = 84281096 + const onLinkFlagBit = 7 + const slaacFlagBit = 6 + boolToByte := func(v bool) byte { + if v { + return 1 + } + return 0 + } + flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit + expectedPrefixInformationBytes := [...]byte{ + // Type, Length + 3, 4, + + prefixLength, flags, + + // Valid Lifetime + 1, 2, 3, 4, + + // Preferred Lifetime + 5, 6, 7, 8, + + // Reserved2 + 0, 0, 0, 0, + + // Address + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds) + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds) + if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. 
+ prefixInformationBytes := expectedPrefixInformationBytes + prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1 + binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1) tests := []struct { name string buf []byte + opt NDPOption expectedBuf []byte - addr tcpip.LinkAddress + check func(*testing.T, NDPOption) }{ { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", + name: "Nonce", + buf: make([]byte, 8), + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 6}), + }, + { + name: "Nonce with padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 0}), + }, + + { + name: "TLL Ethernet", + buf: make([]byte, 8), + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Empty", + buf: nil, + opt: NDPTargetLinkLayerAddressOption(""), + expectedBuf: nil, }, + { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", + name: "SLL Ethernet", + buf: make([]byte, 8), + opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), }, { - "Empty", - nil, - nil, - "", + name: "SLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: 
NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "SLL Empty", + buf: nil, + opt: NDPSourceLinkLayerAddressOption(""), + expectedBuf: nil, + }, + + { + name: "RDNSS", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPRecursiveDNSServer holds the option after the header bytes. + opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]), + expectedBuf: expectedRDNSSBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpRecursiveDNSServerOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType) + } + rdnss, ok := opt.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt) + } + if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + if addrs, err := rdnss.Addresses(); err != nil { + t.Errorf("Addresses(): %s", err) + } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) + } + }, + }, + + { + name: "Search list", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]), + expectedBuf: expectedSearchListBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpDNSSearchListOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType) + } + + dnssl, ok := opt.(NDPDNSSearchList) + if !ok { + t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt) + } + if got, want := dnssl.length(), 
len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + + if domainNames, err := dnssl.DomainNames(); err != nil { + t.Errorf("DomainNames(): %s", err) + } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + }, + }, + + { + name: "Prefix Information", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPPrefixInformation holds the option after the header bytes. + opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]), + expectedBuf: expectedPrefixInformationBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpPrefixInformationType { + t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType) + } + + pi, ok := opt.(NDPPrefixInformation) + if !ok { + t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt) + } + + if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got := pi.PrefixLength(); got != prefixLength { + t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength) + } + if got := pi.OnLinkFlag(); got != onLinkFlag { + t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag) + } + if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag { + t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag) + } + if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got ValidLifetime() = %s, want = %s", got, want) + } + if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want { + t.Errorf("got PreferredLifetime() = %s, want = 
%s", got, want) + } + if got := pi.Prefix(); got != address { + t.Errorf("got Prefix() = %s, want = %s", got, address) + } + }, }, } @@ -344,230 +538,47 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := NDPOptions(test.buf) serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), + test.opt, } if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) + t.Fatalf("got Length() = %d, want = %d", got, want) } opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" { + t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff) } it, err := opts.Iter(true) if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err) } if len(test.expectedBuf) > 0 { next, done, err := it.Next() if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + t.Fatal("got Next() = (_, true, _), want = (_, false, _)") } + test.check(t, next) } // Iterator should not return anything else. 
next, done, err := it.Next() if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") + t.Error("got Next() = (_, false, _), want = (_, true, _)") } if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) } }) } } -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength 
= %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next 
= (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, want); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - func TestNDPRecursiveDNSServerOption(t *testing.T) { tests := []struct { name string @@ -635,8 +646,8 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { if done { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + if got := next.kind(); got != ndpRecursiveDNSServerOptionType { + t.Fatalf("got Type = %d, want = %d", got, ndpRecursiveDNSServerOptionType) } opt, ok := next.(NDPRecursiveDNSServer) @@ -1060,86 +1071,6 @@ func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { } } -func TestNDPDNSSearchListOptionSerialize(t *testing.T) { - b := []byte{ - 9, 
8, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 31, 3, 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPDNSSearchList(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPDNSSearchListOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) - } - - opt, ok := next.(NDPDNSSearchList) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) - } - if got := opt.Type(); got != 31 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16777216*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - domainNames, err := opt.DomainNames() - if err != nil { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { - t.Errorf("domain names mismatch (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. 
- next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -1472,8 +1403,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) } // Test the next (Target Link-Layer) option. @@ -1487,8 +1418,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) } // Test the next (Prefix Information) option. 
@@ -1503,8 +1434,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + if got := next.kind(); got != ndpPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, ndpPrefixInformationType) } // Iterator should not return anything else. diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go index 6fe9a336b..55ab1d7cf 100644 --- a/pkg/tcpip/header/ndpoptionidentifier_string.go +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. +// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT. package header @@ -22,29 +22,37 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. 
var x [1]struct{} - _ = x[NDPSourceLinkLayerAddressOptionType-1] - _ = x[NDPTargetLinkLayerAddressOptionType-2] - _ = x[NDPPrefixInformationType-3] - _ = x[NDPRecursiveDNSServerOptionType-25] + _ = x[ndpSourceLinkLayerAddressOptionType-1] + _ = x[ndpTargetLinkLayerAddressOptionType-2] + _ = x[ndpPrefixInformationType-3] + _ = x[ndpNonceOptionType-14] + _ = x[ndpRecursiveDNSServerOptionType-25] + _ = x[ndpDNSSearchListOptionType-31] } const ( - _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" - _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" + _ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType" + _ndpOptionIdentifier_name_1 = "ndpNonceOptionType" + _ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType" + _ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType" ) var ( - _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} + _ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} ) -func (i NDPOptionIdentifier) String() string { +func (i ndpOptionIdentifier) String() string { switch { case 1 <= i && i <= 3: i -= 1 - return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] + return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]] + case i == 14: + return _ndpOptionIdentifier_name_1 case i == 25: - return _NDPOptionIdentifier_name_1 + return _ndpOptionIdentifier_name_2 + case i == 31: + return _ndpOptionIdentifier_name_3 default: - return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ae0461a6d..7ae38d684 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -38,6 +38,7 @@ const ( var _ 
stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) +var _ ip.DADProtocol = (*endpoint)(nil) // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only @@ -82,7 +83,8 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error { +// SendDADMessage implements ip.DADProtocol. +func (e *endpoint) SendDADMessage(addr tcpip.Address, _ []byte) tcpip.Error { return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress) } @@ -284,9 +286,12 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran e.mu.Lock() e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: e, - NICID: nic.ID(), + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + // ARP does not support sending nonce values. + NonceSize: 0, + Protocol: e, + NICID: nic.ID(), }) e.mu.Unlock() @@ -305,8 +310,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. 
func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - nicID := e.nic.ID() - stats := e.stats.arp if len(remoteLinkAddr) == 0 { @@ -314,9 +317,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if !ok { - return &tcpip.ErrUnknownNICID{} + addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber) + if err != nil { + return err } if len(addr.Address) == 0 { diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index 0053646ee..eed49f5d2 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -16,14 +16,27 @@ package ip import ( + "bytes" "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +type extendRequest int + +const ( + notRequested extendRequest = iota + requested + extended +) + type dadState struct { + nonce []byte + extendRequest extendRequest + done *bool timer tcpip.Timer @@ -33,14 +46,17 @@ type dadState struct { // DADProtocol is a protocol whose core state machine can be represented by DAD. type DADProtocol interface { // SendDADMessage attempts to send a DAD probe message. - SendDADMessage(tcpip.Address) tcpip.Error + SendDADMessage(tcpip.Address, []byte) tcpip.Error } // DADOptions holds options for DAD. type DADOptions struct { - Clock tcpip.Clock - Protocol DADProtocol - NICID tcpip.NICID + Clock tcpip.Clock + SecureRNG io.Reader + NonceSize uint8 + ExtendDADTransmits uint8 + Protocol DADProtocol + NICID tcpip.NICID } // DAD performs duplicate address detection for addresses. 
@@ -63,6 +79,10 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic("attempted to initialize DAD state twice") } + if opts.NonceSize != 0 && opts.ExtendDADTransmits == 0 { + panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) + } + *d = DAD{ opts: opts, configs: configs, @@ -96,10 +116,55 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet s = dadState{ done: &done, timer: d.opts.Clock.AfterFunc(0, func() { - var err tcpip.Error dadDone := remaining == 0 + + nonce, earlyReturn := func() ([]byte, bool) { + d.protocolMU.Lock() + defer d.protocolMU.Unlock() + + if done { + return nil, true + } + + s, ok := d.addresses[addr] + if !ok { + panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID)) + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds + // after having sent DupAddrDetectTransmits NS(DAD) messages, the + // interface continues with another MAX_MULTICAST_SOLICIT number of + // NS(DAD) messages transmitted RetransTimer milliseconds apart. 
+ if dadDone && s.extendRequest == requested { + dadDone = false + remaining = d.opts.ExtendDADTransmits + s.extendRequest = extended + } + + if !dadDone && d.opts.NonceSize != 0 { + if s.nonce == nil { + s.nonce = make([]byte, d.opts.NonceSize) + } + + if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil { + panic(fmt.Sprintf("SecureRNG.Read(...): %s", err)) + } else if n != len(s.nonce) { + panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n)) + } + } + + d.addresses[addr] = s + return s.nonce, false + }() + if earlyReturn { + return + } + + var err tcpip.Error if !dadDone { - err = d.opts.Protocol.SendDADMessage(addr) + err = d.opts.Protocol.SendDADMessage(addr, nonce) } d.protocolMU.Lock() @@ -142,6 +207,68 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet return ret } +// ExtendIfNonceEqualLockedDisposition enumerates the possible results from +// ExtendIfNonceEqualLocked. +type ExtendIfNonceEqualLockedDisposition int + +const ( + // Extended indicates that the DAD process was extended. + Extended ExtendIfNonceEqualLockedDisposition = iota + + // AlreadyExtended indicates that the DAD process was already extended. + AlreadyExtended + + // NoDADStateFound indicates that DAD state was not found for the address. + NoDADStateFound + + // NonceDisabled indicates that nonce values are not sent with DAD messages. + NonceDisabled + + // NonceNotEqual indicates that the nonce value passed and the nonce in the + // last send DAD message are not equal. + NonceNotEqual +) + +// ExtendIfNonceEqualLocked extends the DAD process if the provided nonce is the +// same as the nonce sent in the last DAD message. +// +// Precondition: d.protocolMU must be locked. 
+func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition { + s, ok := d.addresses[addr] + if !ok { + return NoDADStateFound + } + + if d.opts.NonceSize == 0 { + return NonceDisabled + } + + if s.extendRequest != notRequested { + return AlreadyExtended + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds after having + // sent DupAddrDetectTransmits NS(DAD) messages, the interface continues + // with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted + // RetransTimer milliseconds apart. + // + // If a DAD message has already been sent and the nonce value we observed is + // the same as the nonce value we last sent, then we assume our probe was + // looped back and request an extension to the DAD process. + // + // Note, the first DAD message is sent asynchronously so we need to make sure + // that we sent a DAD message by checking if we have a nonce value set. + if s.nonce != nil && bytes.Equal(s.nonce, nonce) { + s.extendRequest = requested + d.addresses[addr] = s + return Extended + } + + return NonceNotEqual +} + // StopLocked stops a currently running DAD process. // // Precondition: d.protocolMU must be locked. diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index e00aa4678..a22b712c6 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "testing" "time" @@ -32,8 +33,8 @@ type mockDADProtocol struct { mu struct { sync.Mutex - dad ip.DAD - sendCount map[tcpip.Address]int + dad ip.DAD + sentNonces map[tcpip.Address][][]byte } } @@ -48,26 +49,30 @@ func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip. 
} func (m *mockDADProtocol) initLocked() { - m.mu.sendCount = make(map[tcpip.Address]int) + m.mu.sentNonces = make(map[tcpip.Address][][]byte) } -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { m.mu.Lock() defer m.mu.Unlock() - m.mu.sendCount[addr]++ + m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce) return nil } func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendCount := make(map[tcpip.Address]int) + sentNonces := make(map[tcpip.Address][][]byte) for _, a := range addrs { - sendCount[a]++ + sentNonces[a] = append(sentNonces[a], nil) } - diff := cmp.Diff(sendCount, m.mu.sendCount) + return m.checkWithNonce(sentNonces) +} + +func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string { + m.mu.Lock() + defer m.mu.Unlock() + + diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces) m.initLocked() return diff } @@ -84,6 +89,12 @@ func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { m.mu.dad.StopLocked(addr, reason) } +func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce) +} + func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { m.mu.Lock() defer m.mu.Unlock() @@ -277,3 +288,94 @@ func TestDADStop(t *testing.T) { default: } } + +func TestNonce(t *testing.T) { + const ( + nonceSize = 2 + + extendRequestAttempts = 2 + + dupAddrDetectTransmits = 2 + extendTransmits = 5 + ) + + var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte + for i := range secureRNGBytes { + secureRNGBytes[i] = byte(i) + } + + tests := []struct { + name string + mockedReceivedNonce []byte + expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition 
+ expectedTransmits int + }{ + { + name: "not matching", + mockedReceivedNonce: []byte{0, 0}, + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual}, + expectedTransmits: dupAddrDetectTransmits, + }, + { + name: "matching nonce", + mockedReceivedNonce: secureRNGBytes[:nonceSize], + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended}, + expectedTransmits: dupAddrDetectTransmits + extendTransmits, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var dad mockDADProtocol + clock := faketime.NewManualClock() + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: dupAddrDetectTransmits, + RetransmitTimer: time.Second, + } + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) + dad.init(t, dadConfigs, ip.DADOptions{ + Clock: clock, + SecureRNG: &secureRNG, + NonceSize: nonceSize, + ExtendDADTransmits: extendTransmits, + }) + + ch := make(chan dadResult, 1) + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + + clock.Advance(0) + for i, want := range test.expectedResults { + if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { + t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) + } + } + + for i := 0; i < test.expectedTransmits; i++ { + if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{ + addr1: { + secureRNGBytes[nonceSize*i:][:nonceSize], + }, + }); diff != "" { + t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff) + } + + clock.Advance(dadConfigs.RetransmitTimer) + } + + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should not have anymore updates. 
+ select { + case r := <-ch: + t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index aee1652fa..a4edc69c7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 8a2140ebe..a1660e9a3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if !e.protocol.options.AllowExternalLoopbackTraffic { + if header.IsV4LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV4LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - 
e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats - h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.ip.MalformedPacketsReceived.Increment() - return - } - - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). 
- // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { - stats.ip.MalformedPacketsReceived.Increment() - return - } - srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1114,13 +1097,46 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv4(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false + } + + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. 
+ // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + return nil, false } if hasTransportHdr { @@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN type Options struct { // IGMP holds options for IGMP. IGMP IGMPOptions + + // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be accepted. + AllowExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv4 network protocol. @@ -1599,9 +1619,8 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt // TODO(https://gvisor.dev/issue/4586): This will need tweaking when we start // really forwarding packets as we may need to get two addresses, for rx and // tx interfaces. We will also have to take usage into account. 
- prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) - localAddress := prefixedAddress.Address - if !ok { + localAddress := e.MainAddress().Address + if len(localAddress) == 0 { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6344a3e09..2afa856dc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -369,6 +369,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } + var it header.NDPOptionIterator + { + var err error + it, err = ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the + // packet. + received.invalid.Increment() + return + } + } + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is @@ -382,6 +394,22 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // stack know so it can handle such a scenario and do nothing further with // the NS. if srcAddr == header.IPv6Any { + var nonce []byte + for { + opt, done, err := it.Next() + if err != nil { + received.invalid.Increment() + return + } + if done { + break + } + if n, ok := opt.(header.NDPNonceOption); ok { + nonce = n.Nonce() + break + } + } + // Since this is a DAD message we know the sender does not actually hold // the target address so there is no "holder". var holderLinkAddress tcpip.LinkAddress @@ -397,7 +425,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. 
- switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) @@ -418,21 +446,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - var sourceLinkAddr tcpip.LinkAddress - { - it, err := ns.Options().Iter(false /* check */) - if err != nil { - // Options are not valid as per the wire format, silently drop the - // packet. - received.invalid.Increment() - return - } - - sourceLinkAddr, ok = getSourceLinkAddr(it) - if !ok { - received.invalid.Increment() - return - } + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.invalid.Increment() + return } // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST @@ -586,6 +603,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r e.dad.mu.Unlock() if e.hasTentativeAddr(targetAddr) { + // We only send a nonce value in DAD messages to check for loopedback + // messages so we use the empty nonce value here. + var nonce []byte + // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with @@ -602,7 +623,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. 
- switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: return default: @@ -899,13 +920,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { + // Find an address that we can use as our source address. addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { + addressEndpoint.DecRef() + } else if !e.checkLocalAddress(localAddr) { + // The provided local address is not assigned to us. return &tcpip.ErrBadLocalAddress{} } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d4e63710c..47d713f88 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 46b6cc41a..83e98bab9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. 
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress, nonce []byte) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -361,27 +361,48 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t return &tcpip.ErrInvalidEndpointState{} } - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an - // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { - return err - } + switch result := e.mu.ndp.dad.ExtendIfNonceEqualLocked(addr, nonce); result { + case ip.Extended: + // The nonce we got back was the same we sent so we know the message + // indicating a duplicate address was likely ours so do not consider + // the address duplicate here. + return nil + case ip.AlreadyExtended: + // See Extended. + // + // Our DAD message was looped back already. + return nil + case ip.NoDADStateFound: + panic(fmt.Sprintf("expected DAD state for tentative address %s", addr)) + case ip.NonceDisabled: + // If nonce is disabled then we have no way to know if the packet was + // looped-back so we have to assume it indicates a duplicate address. + fallthrough + case ip.NonceNotEqual: + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. 
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { + return err + } - prefix := addressEndpoint.Subnet() + prefix := addressEndpoint.Subnet() - switch t := addressEndpoint.ConfigType(); t { - case stack.AddressConfigStatic: - case stack.AddressConfigSlaac: - e.mu.ndp.regenerateSLAACAddr(prefix) - case stack.AddressConfigSlaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil default: - panic(fmt.Sprintf("unrecognized address config type = %d", t)) + panic(fmt.Sprintf("unhandled result = %d", result)) } - - return nil } // transitionForwarding transitions the endpoint's forwarding status to @@ -863,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. 
- if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -904,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if !e.protocol.options.AllowExternalLoopbackTraffic { + if header.IsV6LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV6LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -932,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. 
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip - - h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.MalformedPacketsReceived.Increment() - return - } srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1797,16 +1826,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran dispatcher: dispatcher, protocol: p, } + + // NDP options must be 8 octet aligned and the first 2 bytes are used for + // the type and length fields leaving 6 octets as the minimum size for a + // nonce option without padding. + const nonceSize = 6 + + // As per RFC 7527 section 4.1, + // + // If any probe is looped back within RetransTimer milliseconds after + // having sent DupAddrDetectTransmits NS(DAD) messages, the interface + // continues with another MAX_MULTICAST_SOLICIT number of NS(DAD) + // messages transmitted RetransTimer milliseconds apart. + // + // Value taken from RFC 4861 section 10. + const maxMulticastSolicit = 3 + dadOptions := ip.DADOptions{ + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + NonceSize: nonceSize, + ExtendDADTransmits: maxMulticastSolicit, + Protocol: &e.mu.ndp, + NICID: nic.ID(), + } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp.init(e) + e.mu.ndp.init(e, dadOptions) e.mu.mld.init(e) e.dad.mu.Lock() - e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: &e.mu.ndp, - NICID: nic.ID(), - }) + e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, dadOptions) e.dad.mu.Unlock() e.mu.Unlock() @@ -1879,13 +1928,21 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. 
+// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv6(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false } if hasTransportHdr { @@ -1899,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -2013,6 +2070,10 @@ type Options struct { // DADConfigs holds the default DAD configurations used by IPv6 endpoints. DADConfigs stack.DADConfigurations + + // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be accepted. + AllowExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 266a53e3b..81f5f23c3 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -343,6 +343,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // TestAddIpv6Address tests adding IPv6 addresses. 
func TestAddIpv6Address(t *testing.T) { + const nicID = 1 + tests := []struct { name string addr tcpip.Address @@ -367,18 +369,18 @@ func TestAddIpv6Address(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) } - if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { - t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) + if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err) } else if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr) } }) } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 9a425e50a..85a8f9944 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -15,6 +15,7 @@ package ipv6_test import ( + "bytes" "testing" "time" @@ -119,11 +120,26 @@ func TestSendQueuedMLDReports(t *testing.T) { }, } + nonce := [...]byte{ + 1, 2, 3, 4, 5, 6, + } + + const maxNSMessages = 2 + secureRNGBytes := make([]byte, len(nonce)*maxNSMessages) + for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] { + if n := copy(b, nonce[:]); n != len(nonce) { + t.Fatalf("got 
copy(...) = %d, want = %d", n, len(nonce)) + } + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) clock := faketime.NewManualClock() + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, @@ -154,7 +170,7 @@ func TestSendQueuedMLDReports(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}), )) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d9b728878..536493f87 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1789,18 +1789,14 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer = timer{} } -func (ndp *ndpState) init(ep *endpoint) { +func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { if ndp.defaultRouters != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs - ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{ - Clock: ep.protocol.stack.Clock(), - Protocol: ndp, - NICID: ep.nic.ID(), - }) + ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) @@ -1811,9 +1807,11 @@ func (ndp *ndpState) init(ep *endpoint) { } } -func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - 
return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */) + return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), header.NDPOptionsSerializer{ + header.NDPNonceOption(nonce), + }) } func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 6e850fd46..52b9a200c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "strings" "testing" @@ -1264,8 +1265,21 @@ func TestCheckDuplicateAddress(t *testing.T) { DupAddrDetectTransmits: 1, RetransmitTimer: time.Second, } + + nonces := [...][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) 
+ } + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - Clock: clock, + SecureRNG: &secureRNG, + Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1278,10 +1292,36 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - dadPacketsSent := 1 + dadPacketsSent := 0 + snmc := header.SolicitedNodeAddr(lladdr0) + remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) + checkDADMsg := func() { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("expected %d-th DAD message", dadPacketsSent) + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber) + } + + if p.Route.RemoteLinkAddress != remoteLinkAddr { + t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), + )) + } if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } + checkDADMsg() // Start DAD for the address we just added. // @@ -1297,6 +1337,7 @@ func TestCheckDuplicateAddress(t *testing.T) { } else if res != stack.DADStarting { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) } + checkDADMsg() // Remove the address and make sure our DAD request was not stopped. 
if err := s.RemoveAddress(nicID, lladdr0); err != nil { @@ -1328,33 +1369,6 @@ func TestCheckDuplicateAddress(t *testing.T) { default: } - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - - for i := 0; i < dadPacketsSent; i++ { - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", i) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions(nil), - )) - } - // Should have no more packets. if p, ok := e.Read(); ok { t.Errorf("got unexpected packet = %#v", p) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 47796a6ba..43e6d102c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "context" "encoding/binary" "fmt" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" @@ -358,6 +360,66 @@ func TestDADDisabled(t *testing.T) { } } +func TestDADResolveLoopback(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + + dadConfigs := stack.DADConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 1, + } + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: 
clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + DADConfigs: dadConfigs, + })}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + + // DAD should not resolve after the normal resolution time since our DAD + // message was looped back - we should extend our DAD process. + dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer + clock.Advance(dadResolutionTime) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // Make sure the address does not resolve before the extended resolution time + // has passed. + const delta = time.Nanosecond + // DAD will send extra NS probes if an NS message is looped back. + const extraTransmits = 3 + clock.Advance(dadResolutionTime*extraTransmits - delta) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // DAD should now resolve. + clock.Advance(delta) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. 
// Included in the subtests is a test to make sure that an invalid @@ -404,6 +466,16 @@ func TestDADResolve(t *testing.T) { }, } + nonces := [][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) + } + for _, test := range tests { test := test @@ -419,7 +491,12 @@ func TestDADResolve(t *testing.T) { headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes) + s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -553,7 +630,7 @@ func TestDADResolve(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}), )) if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 62f7c880e..ca15c0691 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress { return addrs } -// primaryAddress returns the primary address associated with this NIC. -// -// primaryAddress will return the first non-deprecated address if such an -// address exists. If no non-deprecated address exists, the first deprecated -// address will be returned. -func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { +// PrimaryAddress implements NetworkInterface. 
+func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { ep, ok := n.networkEndpoints[proto] if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{} } - return addressableEndpoint.MainAddress() + return addressableEndpoint.MainAddress(), nil } // removeAddress removes an address from n. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85f0f471a..ff3a385e1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -525,6 +525,14 @@ type NetworkInterface interface { // assigned to it. Spoofing() bool + // PrimaryAddress returns the primary address associated with the interface. + // + // PrimaryAddress will return the first non-deprecated address if such an + // address exists. If no non-deprecated addresses exist, the first deprecated + // address will be returned. If no deprecated addresses exist, the zero value + // will be returned. + PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) + // CheckLocalAddress returns true if the address exists on the interface. CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 53370c354..931a97ddc 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -23,6 +23,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io" mathrand "math/rand" "sync/atomic" "time" @@ -445,6 +446,9 @@ type Stack struct { // used when a random number is required. randomGenerator *mathrand.Rand + // secureRNG is a cryptographically secure random number generator. + secureRNG io.Reader + // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. 
sendBufferSize tcpip.SendBufferSizeOption @@ -528,6 +532,9 @@ type Options struct { // IPTables are the initial iptables rules. If nil, iptables will allow // all traffic. IPTables *IPTables + + // SecureRNG is a cryptographically secure random number generator. + SecureRNG io.Reader } // TransportEndpointInfo holds useful information about a transport endpoint @@ -636,6 +643,10 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() + if opts.SecureRNG == nil { + opts.SecureRNG = rand.Reader + } + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -652,6 +663,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -1211,20 +1223,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { } // GetMainNICAddress returns the first non-deprecated primary address and prefix -// for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns false if -// the NIC doesn't exist and an empty value if the NIC doesn't have a primary -// address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { +// for the given NIC and protocol. If no non-deprecated primary addresses exist, +// a deprecated address will be returned. If no deprecated addresses exist, the +// zero value will be returned. 
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, false + return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownNICID{} } - return nic.primaryAddress(protocol), true + return nic.PrimaryAddress(protocol) } func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -2047,6 +2058,12 @@ func (s *Stack) Rand() *mathrand.Rand { return s.randomGenerator } +// SecureRNG returns the stack's cryptographically secure random number +// generator. +func (s *Stack) SecureRNG() io.Reader { + return s.secureRNG +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 880219007..7ddf7a083 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -62,10 +62,10 @@ const ( ) func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { - if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + if addr, err := s.GetMainNICAddress(nicID, proto); err != nil { + return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err) } else if addr != want { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want) } return nil } @@ -1854,6 +1854,8 @@ func TestNetworkOption(t *testing.T) { } func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { + const nicID = 1 + for _, addrLen := range []int{4, 16} { t.Run(fmt.Sprintf("addrLen=%d", 
addrLen), func(t *testing.T) { for canBe := 0; canBe < 3; canBe++ { @@ -1864,8 +1866,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Insert <canBe> primary and <never> never-primary addresses. // Each one will add a network endpoint to the NIC. @@ -1888,34 +1890,34 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { PrefixLen: addrLen * 8, }, } - if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { - t.Fatal("AddProtocolAddressWithOptions failed:", err) + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) } } } // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if len(primaryAddrAdded) == 0 { // No primary addresses present. 
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1926,6 +1928,45 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } } +func TestGetMainNICAddressErrors(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // Sanity check with a successful call. + if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil { + t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want) + } + + const unknownNICID = nicID + 1 + switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) { + case *tcpip.ErrUnknownNICID: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err) + } + + // ARP is not an addressable network endpoint. 
+ switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) { + case *tcpip.ErrNotSupported: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err) + } + + const unknownProtocolNumber = 1234 + switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) { + case *tcpip.ErrUnknownProtocol: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err) + } +} + func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, @@ -2507,11 +2548,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { t.Fatal(err) } - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + + // Disabling the NIC should remove the auto-generated address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } }) @@ -2617,6 +2662,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected // when an address's kind gets "promoted" to permanent from permanentExpired. 
func TestNewPEBOnPromotionToPermanent(t *testing.T) { + const nicID = 1 + pebs := []stack.PrimaryEndpointBehavior{ stack.NeverPrimaryEndpoint, stack.CanBePrimaryEndpoint, @@ -2630,8 +2677,8 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) + if err := s.CreateNIC(nicID, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Add a permanent address with initial @@ -2639,20 +2686,21 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // NeverPrimaryEndpoint, the address should not // be returned by a call to GetMainNICAddress; // else, it should. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + const address1 = tcpip.Address("\x01") + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) } - addr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) } - } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) + } else if addr.Address != address1 { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) } { 
@@ -2670,13 +2718,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // new peb is respected when an address gets // "promoted" to permanent from a // permanentExpired kind. - r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + const address2 = tcpip.Address("\x02") + r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err) } defer r.Release() - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + if err := s.RemoveAddress(nicID, address1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err) } // @@ -2687,19 +2736,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) + const address3 = tcpip.Address("\x03") + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). 
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) } var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[1].ProtocolAddresses { + for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) } var expectedList []tcpip.Address @@ -2728,20 +2778,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // should be returned by a call to // GetMainNICAddress; else, our original address // should be returned. - if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + if err := s.RemoveAddress(nicID, address3); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err) } - addr, ok = s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + addr, err = s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) } } else { - if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) + if addr.Address != address1 { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) } } }) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 10cbbe589..c1c6cbccd 100644 --- 
a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -33,8 +33,8 @@ import ( ) const ( - testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 58aabe547..3cc8c36f1 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -72,11 +72,13 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 80afc2825..6462e9d42 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -24,11 +24,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { }) } } + +func 
TestExternalLoopbackTraffic(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + numPackets = 1 + ) + + loopbackSourcedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + } + + loopbackSourcedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + } + + loopbackDestinedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + } + + loopbackDestinedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + } + + invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidSourceAddressesReceived + } + + invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidDestinationAddressesReceived + } + + tests := []struct { + name string + allowExternalLoopback bool + forwarding bool + rxICMP func(*channel.Endpoint) + invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter + shouldAccept bool + }{ + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external 
loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + 
shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: test.allowExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: test.allowExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := 
s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + } + + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + } + + if test.forwarding { + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: ipv4Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + tcpip.Route{ + Destination: header.IPv6Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + }) + + stats := s.Stats().IP + invalidAddressStat := test.invalidAddressStat(stats) + deliveredPacketsStat := stats.PacketsDelivered + if got := invalidAddressStat.Value(); got != 0 { + t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) + } + if got := deliveredPacketsStat.Value(); got != 0 { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) + } + test.rxICMP(e) + var expectedInvalidPackets uint64 + if !test.shouldAccept { + 
expectedInvalidPackets = numPackets + } + if got := invalidAddressStat.Value(); got != expectedInvalidPackets { + t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) + } + if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 29266a4fc..77f4a88ec 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -45,82 +45,61 @@ const ( func TestPingMulticastBroadcast(t *testing.T) { const nicID = 1 - rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: utils.RemoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: utils.RemoteIPv6Addr, - Dst: dst, - })) - ip := 
header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: utils.RemoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - dstAddr tcpip.Address + name string + protoNum tcpip.NetworkProtocolNumber + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + srcAddr tcpip.Address + dstAddr tcpip.Address + expectedSrc tcpip.Address }{ { - name: "IPv4 unicast", - dstAddr: utils.Ipv4Addr.Address, + name: "IPv4 unicast", + protoNum: header.IPv4ProtocolNumber, + dstAddr: utils.Ipv4Addr.Address, + srcAddr: utils.RemoteIPv4Addr, + rxICMP: utils.RxICMPv4EchoRequest, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 directed broadcast", - dstAddr: utils.Ipv4SubnetBcast, + name: "IPv4 directed broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4SubnetBcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 broadcast", - dstAddr: header.IPv4Broadcast, + name: "IPv4 broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4Broadcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 all-systems multicast", - dstAddr: header.IPv4AllSystems, + name: "IPv4 all-systems multicast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4AllSystems, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv6 unicast", - dstAddr: utils.Ipv6Addr.Address, + name: "IPv6 unicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + 
expectedSrc: utils.Ipv6Addr.Address, }, { - name: "IPv6 all-nodes multicast", - dstAddr: header.IPv6AllNodesMulticastAddress, + name: "IPv6 all-nodes multicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: header.IPv6AllNodesMulticastAddress, + expectedSrc: utils.Ipv6Addr.Address, }, } @@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - var rxICMP func(*channel.Endpoint, tcpip.Address) - var expectedSrc tcpip.Address - var expectedDst tcpip.Address - var protoNum tcpip.NetworkProtocolNumber - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - rxICMP = rxIPv4ICMP - expectedSrc = utils.Ipv4Addr.Address - expectedDst = utils.RemoteIPv4Addr - protoNum = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - rxICMP = rxIPv6ICMP - expectedSrc = utils.Ipv6Addr.Address - expectedDst = utils.RemoteIPv6Addr - protoNum = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - rxICMP(e, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") } - if pkt.Route.LocalAddress != expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + if pkt.Route.LocalAddress != test.expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) } - if pkt.Route.RemoteAddress != expectedDst { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + // The destination of the response packet should be the source of the + // original packet. 
+ if pkt.Route.RemoteAddress != test.srcAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if src != test.expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) } - if dst != expectedDst { - t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if dst != test.srcAddr { + t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 4455f6dd7..ed499179f 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -16,6 +16,7 @@ package route_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } + for _, allowExternalLoopback := range []bool{true, false} { + t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: 
allowExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: allowExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - { - err := ep.Connect(connAddr) - if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) - } - } + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() - if test.expectedConnectErr != nil { - return - } + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } - payload := test.icmpBuf(t) - var r bytes.Reader - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := 
ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) - } + if test.expectedConnectErr != nil { + return + } - // Wait for the endpoint to become readable. - <-ch + var r bytes.Reader + payload := test.icmpBuf(t) + r.Reset(payload) + var wOpts tcpip.WriteOptions + if n, err := ep.Write(&r, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } + // Wait for the endpoint to become readable. 
+ <-ch - test.checkLinkEndpoint(t, e) + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if rr.RemoteAddr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } }) } } diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD index 433004148..a9699a367 100644 --- a/pkg/tcpip/tests/utils/BUILD +++ b/pkg/tcpip/tests/utils/BUILD @@ -8,12 +8,15 @@ go_library( visibility = ["//pkg/tcpip/tests:__subpackages__"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", ], ) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index f414a2234..d1c9f3a94 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -20,13 +20,16 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) // Common NIC IDs used by tests. @@ -45,6 +48,10 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) +const ( + ttl = 255 +) + // Common IP addresses used by tests. 
var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }, }) } + +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on +// the provided endpoint. +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(header.ICMPv4UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. 
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(header.ICMPv6UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3b574837c..0a2f3291c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -20,6 +20,7 @@ import ( "fmt" "hash" "io" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" @@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. 
-func (e *endpoint) deliverAccepted(n *endpoint) { +func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() @@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } select { case e.acceptedChan <- n: + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } e.acceptMu.Unlock() e.waiterQueue.Notify(waiter.EventIn) return @@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.synRcvdCount-- + atomic.AddInt32(&e.synRcvdCount, -1) return err } @@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() + atomic.AddInt32(&e.synRcvdCount, -1) return } ctx.cleanupCompletedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep) + e.deliverAccepted(h.ep, false /*withSynCookie*/) }() // S/R-SAFE: synRcvdCount is the barrier. 
return nil @@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) incSynRcvdCount() bool { e.acceptMu.Lock() - canInc := e.synRcvdCount < cap(e.acceptedChan) + canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) e.acceptMu.Unlock() if canInc { - e.synRcvdCount++ + atomic.AddInt32(&e.synRcvdCount, 1) } return canInc } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) e.acceptMu.Unlock() return full } @@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n) + go e.deliverAccepted(n, true /*withSynCookie*/) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 129f36d11..43d344350 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -532,8 +532,8 @@ type endpoint struct { segmentQueue segmentQueue `state:"wait"` // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state. - synRcvdCount int + // in SYN-RCVD state; this is only accessed atomically. + synRcvdCount int32 // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. 
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index a5c82b8fa..bc6793fc6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -260,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fd499a47b..6c86ae1ae 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -165,7 +165,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) } @@ -178,7 +178,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } @@ -239,7 +239,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := 
stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) @@ -269,7 +269,7 @@ func TestTCPResetsSentIncrement(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -318,7 +318,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) { // Send a SYN request for a closed port. This should elicit an RST // but NOT an ICMPv4 DstUnreachable packet. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -362,7 +362,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,7 +459,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) @@ -483,7 +483,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) @@ -506,14 +506,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - 
c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -537,18 +537,19 @@ func TestConnectResetAfterClose(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil // Close the endpoint, make sure we get a FIN segment, then acknowledge // to complete closure of sender, but don't send our own FIN. ep.Close() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -556,7 +557,7 @@ func TestConnectResetAfterClose(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -570,7 +571,7 @@ func TestConnectResetAfterClose(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -612,7 +613,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -625,12 +626,12 @@ func TestCurrentConnectedIncrement(t *testing.T) { } ep.Close() - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), 
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -638,7 +639,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -655,7 +656,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -666,7 +667,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -690,7 +691,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -699,11 +700,12 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { } // Send a FIN for ESTABLISHED --> CLOSED-WAIT + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -713,7 +715,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -734,7 +736,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -753,7 +755,7 @@ func 
TestClosingWithEnqueuedSegments(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -764,7 +766,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 792, + SeqNum: iss.Add(2), AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -804,7 +806,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -813,11 +815,12 @@ func TestSimpleReceive(t *testing.T) { ept := endpointTester{c.EP} data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -840,7 +843,7 @@ func TestSimpleReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1366,12 +1369,13 @@ func TestTOSV4(t *testing.T) { // Check that data is received. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), // Acknum is initial sequence number + 1 + checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), @@ -1414,12 +1418,13 @@ func TestTrafficClassV6(t *testing.T) { // Check that data is received. b := c.GetV6Packet() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv6(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), @@ -1472,7 +1477,7 @@ func TestConnectBindToDevice(t *testing.T) { tcpHdr := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), @@ -1537,7 +1542,7 @@ func TestSynSent(t *testing.T) { if test.reset { // Send a packet with a proper ACK and a RST flag to cause the socket // to error and close out. 
- iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), @@ -1582,7 +1587,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1593,11 +1598,12 @@ func TestOutOfOrderReceive(t *testing.T) { // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: iss.Add(3), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1607,7 +1613,7 @@ func TestOutOfOrderReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1621,7 +1627,7 @@ func TestOutOfOrderReceive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1639,7 +1645,7 @@ func TestOutOfOrderReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1650,19 +1656,20 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() rcvBufSz := math.MaxUint16 - c.CreateConnected(789, 30000, rcvBufSz) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) ept := endpointTester{c.EP} ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before 
the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < 100; i++ { c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 796, + SeqNum: iss.Add(6), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1671,19 +1678,19 @@ func TestOutOfOrderFlood(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) } - // Send packet with seqnum 793. It must be discarded because the + // Send packet with seqnum as initial + 3. It must be discarded because the // out-of-order buffer was filled by the previous packets. c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: iss.Add(3), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1692,27 +1699,27 @@ func TestOutOfOrderFlood(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) - // Now send the expected packet, seqnum 790. + // Now send the expected packet with initial sequence number. c.SendPacket(data[:3], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - // Check that only packet 790 is acknowledged. + // Check that only packet with initial sequence number is acknowledged. 
checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(793), + checker.TCPAckNum(uint32(iss)+3), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1722,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1732,11 +1739,12 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1753,7 +1761,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1780,7 +1788,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), + SeqNum: iss.Add(seqnum.Size(len(data))), AckNum: c.IRS.Add(seqnum.Size(2)), RcvWnd: 30000, }) @@ -1790,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1800,11 +1808,12 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := 
[]byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1821,7 +1830,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1866,7 +1875,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), + SeqNum: iss.Add(seqnum.Size(len(data))), AckNum: c.IRS.Add(seqnum.Size(2)), RcvWnd: 30000, }) @@ -1876,7 +1885,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) @@ -1897,7 +1906,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.Cleanup() const rcvBufSz = 10 - c.CreateConnected(789, 30000, rcvBufSz) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1913,11 +1922,12 @@ func TestFullWindowReceive(t *testing.T) { for i := range data { data[i] = byte(i % 255) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1934,7 +1944,7 @@ func TestFullWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - 
checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(0), ), @@ -1956,7 +1966,7 @@ func TestFullWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(10), ), @@ -1996,8 +2006,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { } payload := generateRandomPayload(t, payloadSize) payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(789) - seqNum := iss.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Send payload to the endpoint and return the advertised receive window // from the endpoint. @@ -2005,12 +2014,12 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), Flags: header.TCPFlagAck, RcvWnd: 30000, }) - seqNum = seqNum.Add(payloadLen) + iss = iss.Add(payloadLen) pkt := c.GetPacket() return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale @@ -2054,9 +2063,8 @@ func TestNoWindowShrinking(t *testing.T) { // the right edge of the window does not shrink. // NOTE: Netstack doubles the value specified here. rcvBufSize := 65536 - iss := seqnum.Value(789) // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2069,13 +2077,13 @@ func TestNoWindowShrinking(t *testing.T) { // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. 
- seqNum := iss.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) payload := []byte{1} c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2092,20 +2100,20 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data: %v, want: %v", got, want) } - seqNum = seqNum.Add(1) // Verify that the ACK does not shrink the window. pkt := c.GetPacket() + iss = iss.Add(1) checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) // Stash the initial window. initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) @@ -2117,11 +2125,11 @@ func TestNoWindowShrinking(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) // Verify that the ACK does not shrink the window. 
pkt = c.GetPacket() @@ -2129,12 +2137,12 @@ func TestNoWindowShrinking(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } @@ -2145,17 +2153,17 @@ func TestNoWindowShrinking(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(0), ), @@ -2173,7 +2181,7 @@ func TestNoWindowShrinking(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), @@ -2184,7 +2192,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} var r bytes.Reader @@ -2195,12 +2203,13 @@ func TestSimpleSend(t *testing.T) { // Check that data is received. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2214,7 +2223,7 @@ func TestSimpleSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), RcvWnd: 30000, }) @@ -2224,7 +2233,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} var r bytes.Reader @@ -2235,12 +2244,13 @@ func TestZeroWindowSend(t *testing.T) { // Check if we got a zero-window probe. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2250,7 +2260,7 @@ func TestZeroWindowSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2262,7 +2272,7 @@ func TestZeroWindowSend(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2276,7 +2286,7 @@ func TestZeroWindowSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), RcvWnd: 30000, }) @@ -2289,7 +2299,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2303,12 +2313,13 @@ func TestScaledWindowConnect(t *testing.T) { // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2322,7 +2333,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(789, 30000, 65535*3) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) data := []byte{1, 2, 3} var r bytes.Reader @@ -2334,12 +2345,13 @@ func TestNonScaledWindowConnect(t *testing.T) { // Check that data is received, and that advertised window is 0xffff, // that is, that it's not scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2407,12 +2419,13 @@ func TestScaledWindowAccept(t *testing.T) { // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2480,12 +2493,13 @@ func TestNonScaledWindowAccept(t *testing.T) { // Check that data is received, and that advertised window is 0xffff, // that is, that it's not scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2502,7 +2516,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the buffer size such that a window scale of 5 will be used. const bufSz = 65535 * 10 const ws = uint32(5) - c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2510,13 +2524,14 @@ func TestZeroScaledWindowReceive(t *testing.T) { remain := 0 sent := 0 data := make([]byte, 50000) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Keep writing till the window drops below len(data). 
for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2527,7 +2542,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2545,7 +2560,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2556,7 +2571,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2594,7 +2609,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), @@ -2632,7 +2647,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent @@ -2657,6 +2672,7 @@ func TestSegmentMerging(t *testing.T) { } // Check that we get tcp.InitialCwnd packets. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < tcp.InitialCwnd; i++ { b := c.GetPacket() checker.IPv4(t, b, @@ -2664,7 +2680,7 @@ func TestSegmentMerging(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2675,7 +2691,7 @@ func TestSegmentMerging(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. RcvWnd: 30000, }) @@ -2687,7 +2703,7 @@ func TestSegmentMerging(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2701,7 +2717,7 @@ func TestSegmentMerging(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) @@ -2713,7 +2729,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) c.EP.SocketOptions().SetDelayOption(true) @@ -2728,6 +2744,7 @@ func TestDelay(t *testing.T) { } seq := c.IRS.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for _, want := range [][]byte{allData[:1], allData[1:]} { // Check that data is received. 
b := c.GetPacket() @@ -2736,7 +2753,7 @@ func TestDelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2751,7 +2768,7 @@ func TestDelay(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2762,7 +2779,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) c.EP.SocketOptions().SetDelayOption(true) @@ -2776,7 +2793,7 @@ func TestUndelay(t *testing.T) { } seq := c.IRS.Add(1) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Check that data is received. first := c.GetPacket() checker.IPv4(t, first, @@ -2784,7 +2801,7 @@ func TestUndelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2807,7 +2824,7 @@ func TestUndelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2823,7 +2840,7 @@ func TestUndelay(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2845,7 +2862,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload 
/ 256), byte(maxPayload % 256), }) @@ -2861,7 +2878,7 @@ func TestMSSNotDelayed(t *testing.T) { } seq := c.IRS.Add(1) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i, data := range allData { // Check that data is received. packet := c.GetPacket() @@ -2870,7 +2887,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2887,7 +2904,7 @@ func TestMSSNotDelayed(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2912,6 +2929,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { // Check that data is received in chunks. bytesReceived := 0 numPackets := 0 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for bytesReceived != dataLen { b := c.GetPacket() numPackets++ @@ -2921,7 +2939,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2945,7 +2963,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), RcvWnd: 30000, TCPOpts: options, @@ -2961,7 +2979,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -3001,7 
+3019,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -3210,7 +3228,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ) // Send SYN-ACK. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), DstPort: tcpHdr.SourcePort(), @@ -3272,14 +3290,15 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Send RST segment. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, RcvWnd: 30000, }) @@ -3328,14 +3347,15 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send RST segment. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, RcvWnd: 30000, }) @@ -3363,7 +3383,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventHUp) @@ -3426,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) var r bytes.Reader r.Reset(make([]byte, 1)) @@ -3469,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) // Disabling PMTU discovery causes all packets sent from this socket to // have DF=0. This needs to be done because the IPv4 ID uniqueness @@ -3518,19 +3538,20 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. 
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3540,7 +3561,7 @@ func TestFinImmediately(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -3551,7 +3572,7 @@ func TestFinImmediately(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3561,19 +3582,20 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. 
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3584,7 +3606,7 @@ func TestFinRetransmit(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3594,7 +3616,7 @@ func TestFinRetransmit(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -3605,7 +3627,7 @@ func TestFinRetransmit(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3615,7 +3637,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. 
view := make([]byte, 10) @@ -3626,12 +3648,13 @@ func TestFinWithNoPendingData(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3641,7 +3664,7 @@ func TestFinWithNoPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3656,7 +3679,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3667,7 +3690,7 @@ func TestFinWithNoPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3678,7 +3701,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3688,7 +3711,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. 
@@ -3702,13 +3725,14 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := tcp.InitialCwnd; i > 0; i-- { checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3727,7 +3751,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3737,7 +3761,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3747,7 +3771,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3758,7 +3782,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3768,7 +3792,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3778,7 +3802,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf 
*/) // Write something out, and acknowledge it to get cwnd to 2. view := make([]byte, 10) @@ -3789,12 +3813,13 @@ func TestFinWithPendingData(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3804,7 +3829,7 @@ func TestFinWithPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3820,7 +3845,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3836,7 +3861,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3847,7 +3872,7 @@ func TestFinWithPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3857,7 +3882,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3867,7 +3892,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // 
Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -3879,12 +3904,13 @@ func TestFinWithPartialAck(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3894,7 +3920,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3905,7 +3931,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3921,7 +3947,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3937,7 +3963,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3948,7 +3974,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: seqnum.Value(next - 1), RcvWnd: 30000, }) @@ -3961,7 +3987,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: 
seqnum.Value(next), RcvWnd: 30000, }) @@ -4002,17 +4028,18 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) // Open up the window with a scaled value. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 1, }) @@ -4031,7 +4058,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4041,7 +4068,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, }) } @@ -4054,15 +4081,16 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4083,14 +4111,15 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func 
TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) vv := c.BuildSegment(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: seqnum.Value(iss), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4110,14 +4139,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4144,23 +4174,24 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send 200 segments. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) data := []byte{1, 2, 3} for i := 0; i < 200; i++ { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), + SeqNum: iss.Add(seqnum.Size(i * len(data))), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) } // Receive ACKs for all segments. 
- last := seqnum.Value(790 + 200*len(data)) + last := iss.Add(seqnum.Size(200 * len(data))) for { b := c.GetPacket() checker.IPv4(t, b, @@ -4198,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4212,12 +4243,13 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -4232,7 +4264,7 @@ func TestReadAfterClosedState(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -4242,7 +4274,7 @@ func TestReadAfterClosedState(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(791+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4815,7 +4847,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. 
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -4833,6 +4865,7 @@ func TestPathMTUDiscovery(t *testing.T) { receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { var ret []byte + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i, size := range sizes { p := c.GetPacket() if i == which { @@ -4843,7 +4876,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4893,14 +4926,15 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5027,7 +5061,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { } if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) } if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { @@ -5067,7 +5101,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second @@ -5087,13 +5121,14 @@ 
func TestKeepalive(t *testing.T) { // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < 10; i++ { b := c.GetPacket() checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -5103,7 +5138,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS, RcvWnd: 30000, }) @@ -5128,7 +5163,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -5140,7 +5175,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -5153,7 +5188,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -5166,7 +5201,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(next-1)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -5184,7 +5219,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -5215,7 +5250,7 @@ func TestKeepalive(t *testing.T) { func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { t.Helper() // 
Send a SYN request. - irs = seqnum.Value(789) + irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: srcPort, DstPort: context.StackPort, @@ -5260,7 +5295,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { t.Helper() // Send a SYN request. - irs = seqnum.Value(789) + irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendV6Packet(nil, &context.Headers{ SrcPort: srcPort, DstPort: context.StackPort, @@ -5340,7 +5375,7 @@ func TestListenBacklogFull(t *testing.T) { SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), RcvWnd: 30000, }) c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) @@ -5491,7 +5526,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { t.Fatalf("Listen failed: %s", err) } - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacketWithAddrs(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5591,7 +5626,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { t.Fatalf("Listen failed: %s", err) } - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendV6PacketWithAddrs(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5645,7 +5680,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as // the synRcvd count will be equal to backlog. 
- irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5758,7 +5793,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { time.Sleep(50 * time.Millisecond) // Send a SYN request. - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ // pick a different src port for new SYN. SrcPort: context.TestPort + 1, @@ -5824,7 +5859,7 @@ func TestSYNRetransmit(t *testing.T) { // Send the same SYN packet multiple times. We should still get a valid SYN-ACK // reply. - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) for i := 0; i < 5; i++ { c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5867,7 +5902,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6051,7 +6086,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { SrcPort: srcPort + 2, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), RcvWnd: 30000, }) @@ -6501,7 +6536,7 @@ func TestTCPLingerTimeout(t *testing.T) { c := context.New(t, 1500 /* mtu */) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) testCases := []struct { name string @@ -6552,7 +6587,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { } // Send a SYN request. 
- iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6671,7 +6706,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6778,7 +6813,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6852,7 +6887,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Send a SYN request w/ sequence number lower than // the highest sequence number sent. We just reuse // the same number. - iss = seqnum.Value(789) + iss = seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6872,7 +6907,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Send a SYN request w/ sequence number higher than // the highest sequence number sent. - iss = seqnum.Value(792) + iss = iss.Add(3) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6942,7 +6977,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -7091,7 +7126,7 @@ func TestTCPCloseWithData(t *testing.T) { } // Send a SYN request. 
- iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -7242,7 +7277,7 @@ func TestTCPUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventHUp) @@ -7268,12 +7303,13 @@ func TestTCPUserTimeout(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7299,7 +7335,7 @@ func TestTCPUserTimeout(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -7328,7 +7364,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() @@ -7362,11 +7398,12 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Now receive 1 keepalives, but don't ACK it. 
b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7383,7 +7420,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(c.IRS + 1), RcvWnd: 30000, }) @@ -7413,20 +7450,20 @@ func TestIncreaseWindowOnRead(t *testing.T) { defer c.Cleanup() const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for remain > len(data) { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -7438,7 +7475,7 @@ func TestIncreaseWindowOnRead(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7469,7 +7506,7 @@ func TestIncreaseWindowOnRead(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), @@ -7483,20 +7520,20 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { defer c.Cleanup() const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) + 
c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. remain := rcvBuf sent := 0 data := make([]byte, defaultMTU/2) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for remain > len(data) { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -7507,7 +7544,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), @@ -7523,7 +7560,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), @@ -7664,16 +7701,16 @@ func TestResetDuringClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - iss := seqnum.Value(789) - c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) // Send some data to make sure there is some unread // data to trigger a reset on c.Close. 
irs := c.IRS + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), + SeqNum: iss, AckNum: irs.Add(1), RcvWnd: 30000, }) @@ -7683,7 +7720,7 @@ func TestResetDuringClose(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss.Add(5))))) + checker.TCPAckNum(uint32(iss)+4))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7698,7 +7735,7 @@ func TestResetDuringClose(t *testing.T) { c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: iss.Add(5), + SeqNum: iss.Add(4), AckNum: c.IRS.Add(5), RcvWnd: 30000, Flags: header.TCPFlagRst, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 5d81dbb94..c8126b51b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -45,8 +45,8 @@ import ( // represents the remote endpoint. 
const ( v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" stackV4MappedAddr = v4MappedAddrPrefix + stackAddr testV4MappedAddr = v4MappedAddrPrefix + testAddr multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go index 047091e75..dbe17fa5e 100644 --- a/pkg/test/dockerutil/network.go +++ b/pkg/test/dockerutil/network.go @@ -102,11 +102,8 @@ func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) } -// Cleanup cleans up the docker network and all the containers attached to it. +// Cleanup cleans up the docker network. func (n *Network) Cleanup(ctx context.Context) error { - for _, c := range n.containers { - c.CleanUp(ctx) - } n.containers = nil return n.client.NetworkRemove(ctx, n.id) diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 1cd5fba5c..1ae76d7d7 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -400,7 +400,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Set up the restore environment. 
ctx := k.SupervisorContext() - mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints) + mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) if kernel.VFS2Enabled { ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) if err != nil { diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 77f632bb9..32adde643 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -103,14 +103,14 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name // compileMounts returns the supported mounts from the mount spec, adding any // mandatory mounts that are required by the OCI specification. -func compileMounts(spec *specs.Spec) []specs.Mount { +func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { // Keep track of whether proc and sys were mounted. var procMounted, sysMounted, devMounted, devptsMounted bool var mounts []specs.Mount // Mount all submounts from the spec. for _, m := range spec.Mounts { - if !specutils.IsSupportedDevMount(m) { + if !vfs2Enabled && !specutils.IsVFS1SupportedDevMount(m) { log.Warningf("ignoring dev mount at %q", m.Destination) continue } @@ -572,10 +572,10 @@ type containerMounter struct { hints *podMountHints } -func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints) *containerMounter { +func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter { return &containerMounter{ root: spec.Root, - mounts: compileMounts(spec), + mounts: compileMounts(spec, vfs2Enabled), fds: fdDispenser{fds: goferFDs}, k: k, hints: hints, @@ -792,7 +792,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M case bind: fd := c.fds.remove() fsName = gofervfs2.Name - opts = p9MountData(fd, c.getMountAccessType(m), conf.VFS2) + opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2) // 
If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly @@ -802,12 +802,11 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M return fsName, opts, useOverlay, nil } -func (c *containerMounter) getMountAccessType(mount specs.Mount) config.FileAccessType { +func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.Mount) config.FileAccessType { if hint := c.hints.findMount(mount); hint != nil { return hint.fileAccessType() } - // Non-root bind mounts are always shared if no hints were provided. - return config.FileAccessShared + return conf.FileAccessMounts } // mountSubmount mounts volumes inside the container's root. Because mounts may diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index e986231e5..b4f12d034 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -243,7 +243,8 @@ func TestGetMountAccessType(t *testing.T) { t.Fatalf("newPodMountHints failed: %v", err) } mounter := containerMounter{hints: podHints} - if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want { + conf := &config.Config{FileAccessMounts: config.FileAccessShared} + if got := mounter.getMountAccessType(conf, specs.Mount{Source: source}); got != tst.want { t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got) } }) diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 5afce232d..774621970 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -752,7 +752,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn // Setup the child container file system. 
l.startGoferMonitor(cid, info.goferFDs) - mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints) + mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled) if root { if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil { return nil, nil, nil, err diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 3121ca6eb..8b39bc59a 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -439,7 +439,7 @@ func TestCreateMountNamespace(t *testing.T) { } defer cleanup() - mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}) + mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */) mns, err := mntr.createMountNamespace(ctx, conf) if err != nil { t.Fatalf("failed to create mount namespace: %v", err) @@ -479,7 +479,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { defer l.Destroy() defer loaderCleanup() - mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints) + mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */) if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil { t.Fatalf("failed process hints: %v", err) } @@ -702,7 +702,7 @@ func TestRestoreEnvironment(t *testing.T) { for _, ioFD := range tc.ioFDs { ioFDs = append(ioFDs, fd.New(ioFD)) } - mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}) + mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 3fd28e516..9b3dacf46 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -494,7 +494,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf 
*config.Config, m *mo // but unlikely to be correct in this context. return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } - data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) + data = p9MountData(m.fd, c.getMountAccessType(conf, m.Mount), true /* vfs2 */) iopts = gofer.InternalFilesystemOptions{ UniqueID: m.Destination, } diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 22c1dfeb8..455c57692 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -42,10 +42,11 @@ var errNoDefaultInterface = errors.New("no default interface found") // Do implements subcommands.Command for the "do" command. It sets up a simple // sandbox and executes the command inside it. See Usage() for more details. type Do struct { - root string - cwd string - ip string - quiet bool + root string + cwd string + ip string + quiet bool + overlay bool } // Name implements subcommands.Command.Name. @@ -76,6 +77,7 @@ func (c *Do) SetFlags(f *flag.FlagSet) { f.StringVar(&c.cwd, "cwd", ".", "path to the current directory, defaults to the current directory") f.StringVar(&c.ip, "ip", "192.168.10.2", "IPv4 address for the sandbox") f.BoolVar(&c.quiet, "quiet", false, "suppress runsc messages to stdout. Application output is still sent to stdout and stderr") + f.BoolVar(&c.overlay, "force-overlay", true, "use an overlay. WARNING: disabling gives the command write access to the host") } // Execute implements subcommands.Command.Execute. @@ -100,9 +102,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su return Errorf("Error to retrieve hostname: %v", err) } - // Map the entire host file system, but make it readonly with a writable - // overlay on top (ignore --overlay option). - conf.Overlay = true + // Map the entire host file system, optionally using an overlay. 
+ conf.Overlay = c.overlay absRoot, err := resolvePath(c.root) if err != nil { return Errorf("Error resolving root: %v", err) diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 639b2219c..4cb0164dd 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -165,8 +165,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Start with root mount, then add any other additional mount as needed. ats := make([]p9.Attacher, 0, len(spec.Mounts)+1) ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{ - ROMount: spec.Root.Readonly || conf.Overlay, - EnableXattr: conf.Verity, + ROMount: spec.Root.Readonly || conf.Overlay, + EnableVerityXattr: conf.Verity, }) if err != nil { Fatalf("creating attach point: %v", err) @@ -178,9 +178,9 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) for _, m := range spec.Mounts { if specutils.Is9PMount(m) { cfg := fsgofer.Config{ - ROMount: isReadonlyMount(m.Options) || conf.Overlay, - HostUDS: conf.FSGoferHostUDS, - EnableXattr: conf.Verity, + ROMount: isReadonlyMount(m.Options) || conf.Overlay, + HostUDS: conf.FSGoferHostUDS, + EnableVerityXattr: conf.Verity, } ap, err := fsgofer.NewAttachPoint(m.Destination, cfg) if err != nil { @@ -203,6 +203,10 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) filter.InstallUDSFilters() } + if conf.Verity { + filter.InstallXattrFilters() + } + if err := filter.Install(); err != nil { Fatalf("installing seccomp filters: %v", err) } @@ -346,7 +350,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error { // creates directories as needed. 
func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { for _, m := range mounts { - if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { + if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) { continue } @@ -386,7 +390,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { cleanMounts := make([]specs.Mount, 0, len(mounts)) for _, m := range mounts { - if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { + if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) { cleanMounts = append(cleanMounts, m) continue } diff --git a/runsc/config/config.go b/runsc/config/config.go index 34ef48825..1e5858837 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -58,9 +58,12 @@ type Config struct { // DebugLogFormat is the log format for debug. DebugLogFormat string `flag:"debug-log-format"` - // FileAccess indicates how the filesystem is accessed. + // FileAccess indicates how the root filesystem is accessed. FileAccess FileAccessType `flag:"file-access"` + // FileAccessMounts indicates how non-root volumes are accessed. + FileAccessMounts FileAccessType `flag:"file-access-mounts"` + // Overlay is whether to wrap the root filesystem in an overlay. Overlay bool `flag:"overlay"` @@ -197,13 +200,19 @@ func (c *Config) validate() error { type FileAccessType int const ( - // FileAccessExclusive is the same as FileAccessShared, but enables - // extra caching for improved performance. It should only be used if - // the sandbox has exclusive access to the filesystem. + // FileAccessExclusive gives the sandbox exclusive access over files and + // directories in the filesystem. No external modifications are permitted and + // can lead to undefined behavior. + // + // Exclusive filesystem access enables more aggressive caching and offers + // significantly better performance. 
This is the default mode for the root + // volume. FileAccessExclusive FileAccessType = iota - // FileAccessShared sends IO requests to a Gofer process that validates the - // requests and forwards them to the host. + // FileAccessShared is used for volumes that can have external changes. It + // requires revalidation on every filesystem access to detect external + // changes, and reduces the amount of caching that can be done. This is the + // default mode for non-root volumes. FileAccessShared ) diff --git a/runsc/config/flags.go b/runsc/config/flags.go index adbee506c..1d996c841 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -67,7 +67,8 @@ func RegisterFlags() { flag.Bool("oci-seccomp", false, "Enables loading OCI seccomp filters inside the sandbox.") // Flags that control sandbox runtime behavior: FS related. - flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") + flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem validation to use for the root mount: exclusive (default), shared.") + flag.Var(fileAccessTypePtr(FileAccessShared), "file-access-mounts", "specifies which filesystem validation to use for volumes other than the root mount: shared (default), exclusive.") flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. 
All modifications are stored in memory inside the sandbox.") flag.Bool("verity", false, "specifies whether a verity file system will be mounted.") flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index fd72414ce..246b7ed3c 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -247,3 +247,8 @@ var udsSyscalls = seccomp.SyscallRules{ }, }, } + +var xattrSyscalls = seccomp.SyscallRules{ + unix.SYS_FGETXATTR: {}, + unix.SYS_FSETXATTR: {}, +} diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go index 289886720..6c67ee288 100644 --- a/runsc/fsgofer/filter/filter.go +++ b/runsc/fsgofer/filter/filter.go @@ -36,3 +36,9 @@ func InstallUDSFilters() { // Add additional filters required for connecting to the host's sockets. allowedSyscalls.Merge(udsSyscalls) } + +// InstallXattrFilters extends the allowed syscalls to include xattr calls that +// are necessary for Verity enabled file systems. +func InstallXattrFilters() { + allowedSyscalls.Merge(xattrSyscalls) +} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 1e80a634d..e04ddda47 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -48,6 +48,14 @@ const ( allowedOpenFlags = unix.O_TRUNC ) +// verityXattrs are the extended attributes used by verity file system. +var verityXattrs = map[string]struct{}{ + "user.merkle.offset": struct{}{}, + "user.merkle.size": struct{}{}, + "user.merkle.childrenOffset": struct{}{}, + "user.merkle.childrenSize": struct{}{}, +} + // join is equivalent to path.Join() but skips path.Clean() which is expensive. func join(parent, child string) string { if child == "." || child == ".." { @@ -67,8 +75,9 @@ type Config struct { // HostUDS signals whether the gofer can mount a host's UDS. HostUDS bool - // enableXattr allows Get/SetXattr for the mounted file systems. 
- EnableXattr bool + // EnableVerityXattr allows access to extended attributes used by the + // verity file system. + EnableVerityXattr bool } type attachPoint struct { @@ -799,7 +808,10 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } func (l *localFile) GetXattr(name string, size uint64) (string, error) { - if !l.attachPoint.conf.EnableXattr { + if !l.attachPoint.conf.EnableVerityXattr { + return "", unix.EOPNOTSUPP + } + if _, ok := verityXattrs[name]; !ok { return "", unix.EOPNOTSUPP } buffer := make([]byte, size) @@ -810,7 +822,10 @@ func (l *localFile) GetXattr(name string, size uint64) (string, error) { } func (l *localFile) SetXattr(name string, value string, flags uint32) error { - if !l.attachPoint.conf.EnableXattr { + if !l.attachPoint.conf.EnableVerityXattr { + return unix.EOPNOTSUPP + } + if _, ok := verityXattrs[name]; !ok { return unix.EOPNOTSUPP } return unix.Fsetxattr(l.file.FD(), name, []byte(value), int(flags)) diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index a5f09f88f..d7e141476 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -579,20 +579,24 @@ func SetGetXattr(l *localFile, name string, value string) error { return nil } +func TestSetGetDisabledXattr(t *testing.T) { + runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) { + name := "user.merkle.offset" + value := "tmp" + err := SetGetXattr(s.file, name, value) + if err == nil { + t.Fatalf("%v: SetGetXattr should have failed", s) + } + }) +} + func TestSetGetXattr(t *testing.T) { - xattrConfs := []Config{{ROMount: false, EnableXattr: false}, {ROMount: false, EnableXattr: true}} - runCustom(t, []uint32{unix.S_IFREG}, xattrConfs, func(t *testing.T, s state) { - name := "user.test" + runCustom(t, []uint32{unix.S_IFREG}, []Config{{ROMount: false, EnableVerityXattr: true}}, func(t *testing.T, s state) { + name := "user.merkle.offset" value := "tmp" err := SetGetXattr(s.file, name, 
value) - if s.conf.EnableXattr { - if err != nil { - t.Fatalf("%v: SetGetXattr failed, err: %v", s, err) - } - } else { - if err == nil { - t.Fatalf("%v: SetGetXattr should have failed", s) - } + if err != nil { + t.Fatalf("%v: SetGetXattr failed, err: %v", s, err) } }) } diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 5ba38bfe4..45856fd58 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -334,14 +334,13 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth. // Is9PMount returns true if the given mount can be mounted as an external gofer. func Is9PMount(m specs.Mount) bool { - return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m) + return m.Type == "bind" && m.Source != "" && IsVFS1SupportedDevMount(m) } -// IsSupportedDevMount returns true if the mount is a supported /dev mount. -// Only mount that does not conflict with runsc default /dev mount is -// supported. -func IsSupportedDevMount(m specs.Mount) bool { - // These are devices exist inside sentry. See pkg/sentry/fs/dev/dev.go +// IsVFS1SupportedDevMount returns true if m.Destination does not specify a +// path that is hardcoded by VFS1's implementation of /dev. +func IsVFS1SupportedDevMount(m specs.Mount) bool { + // See pkg/sentry/fs/dev/dev.go. 
var existingDevices = []string{ "/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr", "/dev/null", "/dev/zero", "/dev/full", "/dev/random", diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD index c94caab60..dc82e63b2 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -8,6 +8,7 @@ benchmark_test( srcs = ["bazel_test.go"], visibility = ["//:sandbox"], deps = [ + "//pkg/cleanup", "//pkg/test/dockerutil", "//test/benchmarks/harness", "//test/benchmarks/tools", @@ -21,6 +22,7 @@ benchmark_test( srcs = ["fio_test.go"], visibility = ["//:sandbox"], deps = [ + "//pkg/cleanup", "//pkg/test/dockerutil", "//test/benchmarks/harness", "//test/benchmarks/tools", diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index 7ced963f6..797b1952d 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -20,6 +20,7 @@ import ( "strings" "testing" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" @@ -28,8 +29,8 @@ import ( // Dimensions here are clean/dirty cache (do or don't drop caches) // and if the mount on which we are compiling is a tmpfs/bind mount. type benchmark struct { - clearCache bool // clearCache drops caches before running. - fstype string // type of filesystem to use. + clearCache bool // clearCache drops caches before running. + fstype harness.FileSystemType // type of filesystem to use. } // Note: CleanCache versions of this test require running with root permissions. @@ -48,12 +49,12 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { // Get a machine from the Harness on which to run. 
machine, err := harness.GetMachine() if err != nil { - b.Fatalf("failed to get machine: %v", err) + b.Fatalf("Failed to get machine: %v", err) } defer machine.CleanUp() benchmarks := make([]benchmark, 0, 6) - for _, filesys := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} { + for _, filesys := range []harness.FileSystemType{harness.BindFS, harness.TmpFS, harness.RootFS} { benchmarks = append(benchmarks, benchmark{ clearCache: true, fstype: filesys, @@ -75,7 +76,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { filesystem := tools.Parameter{ Name: "filesystem", - Value: bm.fstype, + Value: string(bm.fstype), } name, err := tools.ParametersToName(pageCache, filesystem) if err != nil { @@ -86,13 +87,14 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { // Grab a container. ctx := context.Background() container := machine.GetContainer(ctx, b) - defer container.CleanUp(ctx) - - mts, prefix, cleanup, err := harness.MakeMount(machine, bm.fstype) + cu := cleanup.Make(func() { + container.CleanUp(ctx) + }) + defer cu.Clean() + mts, prefix, err := harness.MakeMount(machine, bm.fstype, &cu) if err != nil { b.Fatalf("Failed to make mount: %v", err) } - defer cleanup() runOpts := dockerutil.RunOpts{ Image: image, @@ -104,8 +106,9 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { b.Fatalf("run failed with: %v", err) } + cpCmd := fmt.Sprintf("mkdir -p %s && cp -r %s %s/.", prefix, workdir, prefix) if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, - "cp", "-rf", workdir, prefix+"/."); err != nil { + "/bin/sh", "-c", cpCmd); err != nil { b.Fatalf("failed to copy directory: %v (%s)", err, out) } diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go index f783a2b33..1482466f4 100644 --- a/test/benchmarks/fs/fio_test.go +++ b/test/benchmarks/fs/fio_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "gvisor.dev/gvisor/pkg/cleanup" 
"gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" @@ -69,7 +70,7 @@ func BenchmarkFio(b *testing.B) { } defer machine.CleanUp() - for _, fsType := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} { + for _, fsType := range []harness.FileSystemType{harness.BindFS, harness.TmpFS, harness.RootFS} { for _, tc := range testCases { operation := tools.Parameter{ Name: "operation", @@ -81,7 +82,7 @@ func BenchmarkFio(b *testing.B) { } filesystem := tools.Parameter{ Name: "filesystem", - Value: fsType, + Value: string(fsType), } name, err := tools.ParametersToName(operation, blockSize, filesystem) if err != nil { @@ -90,15 +91,18 @@ func BenchmarkFio(b *testing.B) { b.Run(name, func(b *testing.B) { b.StopTimer() tc.Size = b.N + ctx := context.Background() container := machine.GetContainer(ctx, b) - defer container.CleanUp(ctx) + cu := cleanup.Make(func() { + container.CleanUp(ctx) + }) + defer cu.Clean() - mnts, outdir, mountCleanup, err := harness.MakeMount(machine, fsType) + mnts, outdir, err := harness.MakeMount(machine, fsType, &cu) if err != nil { b.Fatalf("failed to make mount: %v", err) } - defer mountCleanup() // Start the container with the mount. if err := container.Spawn( @@ -112,6 +116,11 @@ func BenchmarkFio(b *testing.B) { b.Fatalf("failed to start fio container with: %v", err) } + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + "mkdir", "-p", outdir); err != nil { + b.Fatalf("failed to copy directory: %v (%s)", err, out) + } + // Directory and filename inside container where fio will read/write. 
outfile := filepath.Join(outdir, "test.txt") @@ -130,7 +139,6 @@ func BenchmarkFio(b *testing.B) { } cmd := tc.MakeCmd(outfile) - if err := harness.DropCaches(machine); err != nil { b.Fatalf("failed to drop caches: %v", err) } diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD index 116610938..367316661 100644 --- a/test/benchmarks/harness/BUILD +++ b/test/benchmarks/harness/BUILD @@ -12,6 +12,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/cleanup", "//pkg/test/dockerutil", "//pkg/test/testutil", "@com_github_docker_docker//api/types/mount:go_default_library", diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go index 36abe1069..f7e569751 100644 --- a/test/benchmarks/harness/util.go +++ b/test/benchmarks/harness/util.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -58,52 +59,55 @@ func DebugLog(b *testing.B, msg string, args ...interface{}) { } } +// FileSystemType represents a type container mount. +type FileSystemType string + const ( // BindFS indicates a bind mount should be created. - BindFS = "bindfs" + BindFS FileSystemType = "bindfs" // TmpFS indicates a tmpfs mount should be created. - TmpFS = "tmpfs" + TmpFS FileSystemType = "tmpfs" // RootFS indicates no mount should be created and the root mount should be used. - RootFS = "rootfs" + RootFS FileSystemType = "rootfs" ) // MakeMount makes a mount and cleanup based on the requested type. Bind // and volume mounts are backed by a temp directory made with mktemp. // tmpfs mounts require no such backing and are just made. // rootfs mounts do not make a mount, but instead return a target direectory at root. -// It is up to the caller to call the returned cleanup. 
-func MakeMount(machine Machine, fsType string) ([]mount.Mount, string, func(), error) { +// It is up to the caller to call Clean on the passed *cleanup.Cleanup +func MakeMount(machine Machine, fsType FileSystemType, cu *cleanup.Cleanup) ([]mount.Mount, string, error) { mounts := make([]mount.Mount, 0, 1) + target := "/data" switch fsType { case BindFS: dir, err := machine.RunCommand("mktemp", "-d") if err != nil { - return mounts, "", func() {}, fmt.Errorf("failed to create tempdir: %v", err) + return mounts, "", fmt.Errorf("failed to create tempdir: %v", err) } dir = strings.TrimSuffix(dir, "\n") - + cu.Add(func() { + machine.RunCommand("rm", "-rf", dir) + }) out, err := machine.RunCommand("chmod", "777", dir) if err != nil { - machine.RunCommand("rm", "-rf", dir) - return mounts, "", func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) + return mounts, "", fmt.Errorf("failed modify directory: %v %s", err, out) } - target := "/data" mounts = append(mounts, mount.Mount{ Target: target, Source: dir, Type: mount.TypeBind, }) - return mounts, target, func() { machine.RunCommand("rm", "-rf", dir) }, nil + return mounts, target, nil case RootFS: - return mounts, "/", func() {}, nil + return mounts, target, nil case TmpFS: - target := "/data" mounts = append(mounts, mount.Mount{ Target: target, Type: mount.TypeTmpfs, }) - return mounts, target, func() {}, nil + return mounts, target, nil default: - return mounts, "", func() {}, fmt.Errorf("illegal mount type not supported: %v", fsType) + return mounts, "", fmt.Errorf("illegal mount type not supported: %v", fsType) } } diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 567f64c41..34e83ec49 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -203,11 +203,6 @@ ALL_TESTS = [ name = "tcp_outside_the_window", ), PacketimpactTestInfo( - name = "tcp_outside_the_window_closing", - # TODO(b/181625316): Fix netstack then merge into 
tcp_outside_the_window. - expect_netstack_failure = True, - ), - PacketimpactTestInfo( name = "tcp_noaccept_close_rst", ), PacketimpactTestInfo( @@ -217,11 +212,6 @@ ALL_TESTS = [ name = "tcp_unacc_seq_ack", ), PacketimpactTestInfo( - name = "tcp_unacc_seq_ack_closing", - # TODO(b/181625316): Fix netstack then merge into tcp_unacc_seq_ack. - expect_netstack_failure = True, - ), - PacketimpactTestInfo( name = "tcp_paws_mechanism", # TODO(b/156682000): Fix netstack then remove the line below. expect_netstack_failure = True, @@ -289,8 +279,6 @@ ALL_TESTS = [ ), PacketimpactTestInfo( name = "tcp_fin_retransmission", - # TODO(b/181625316): Fix netstack then remove the line below. - expect_netstack_failure = True, ), ] diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index 1064ca976..b271bd47e 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -137,7 +137,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut dn := dn t.Cleanup(func() { if err := dn.Cleanup(ctx); err != nil { - t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + t.Errorf("failed to cleanup network %s: %s", dn.Name, err) } }) // Sanity check. @@ -151,13 +151,15 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut info.testNet = testNet // Create the Docker container for the DUT. 
- var dut DUT + makeContainer := dockerutil.MakeContainer if native { - dut = mkDevice(dockerutil.MakeNativeContainer(ctx, logger(fmt.Sprintf("dut-%d", id)))) - } else { - dut = mkDevice(dockerutil.MakeContainer(ctx, logger(fmt.Sprintf("dut-%d", id)))) + makeContainer = dockerutil.MakeNativeContainer } - info.dut = dut + dutContainer := makeContainer(ctx, logger(fmt.Sprintf("dut-%d", id))) + t.Cleanup(func() { + dutContainer.CleanUp(ctx) + }) + info.dut = mkDevice(dutContainer) runOpts := dockerutil.RunOpts{ Image: "packetimpact", @@ -168,7 +170,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut } ipv4PrefixLength, _ := testNet.Subnet.Mask.Size() - remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := dut.Prepare(ctx, t, runOpts, ctrlNet, testNet) + remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev, err := info.dut.Prepare(ctx, t, runOpts, ctrlNet, testNet) if err != nil { return dutInfo{}, err } @@ -183,7 +185,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet), POSIXServerPort: CtrlPort, } - info.uname, err = dut.Uname(ctx) + info.uname, err = info.dut.Uname(ctx) if err != nil { return dutInfo{}, fmt.Errorf("failed to get uname information on DUT: %w", err) } @@ -231,6 +233,9 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co // Create the Docker container for the testbench. 
testbenchContainer := dockerutil.MakeNativeContainer(ctx, logger("testbench")) + t.Cleanup(func() { + testbenchContainer.CleanUp(ctx) + }) runOpts := dockerutil.RunOpts{ Image: "packetimpact", @@ -598,7 +603,6 @@ func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, sysctls map[string]string, cmd ...string) error { conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...) _ = netconf - hostconf.AutoRemove = true hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} for k, v := range sysctls { hostconf.Sysctls[k] = v diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index d5cb0ae06..c0deb33e5 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -124,17 +124,6 @@ packetimpact_testbench( ) packetimpact_testbench( - name = "tcp_outside_the_window_closing", - srcs = ["tcp_outside_the_window_closing_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "//test/packetimpact/testbench", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -packetimpact_testbench( name = "tcp_noaccept_close_rst", srcs = ["tcp_noaccept_close_rst_test.go"], deps = [ @@ -166,17 +155,6 @@ packetimpact_testbench( ) packetimpact_testbench( - name = "tcp_unacc_seq_ack_closing", - srcs = ["tcp_unacc_seq_ack_closing_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "//test/packetimpact/testbench", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -packetimpact_testbench( name = "tcp_paws_mechanism", srcs = ["tcp_paws_mechanism_test.go"], deps = [ diff --git a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go deleted file mode 100644 index 1097746c7..000000000 --- a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go +++ 
/dev/null @@ -1,86 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_outside_the_window_closing_test - -import ( - "flag" - "fmt" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.Initialize(flag.CommandLine) -} - -// TestAckOTWSeqInClosing tests that the DUT should send an ACK with -// the right ACK number when receiving a packet with OTW Seq number -// in CLOSING state. 
https://tools.ietf.org/html/rfc793#page-69 -func TestAckOTWSeqInClosing(t *testing.T) { - for seqNumOffset := seqnum.Size(0); seqNumOffset < 3; seqNumOffset++ { - for _, tt := range []struct { - description string - flags header.TCPFlags - payloads testbench.Layers - }{ - {"SYN", header.TCPFlagSyn, nil}, - {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil}, - {"ACK", header.TCPFlagAck, nil}, - {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil}, - {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}}, - } { - t.Run(fmt.Sprintf("%s%d", tt.description, seqNumOffset), func(t *testing.T) { - dut := testbench.NewDUT(t) - listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(t, listenFD) - conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close(t) - conn.Connect(t) - acceptFD, _ := dut.Accept(t, listenFD) - defer dut.Close(t, acceptFD) - - dut.Shutdown(t, acceptFD, unix.SHUT_WR) - - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { - t.Fatalf("expected FINACK from DUT, but got none: %s", err) - } - - // Do not ack the FIN from DUT so that the TCP state on DUT is CLOSING instead of CLOSED. 
- seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) - conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) - - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { - t.Errorf("expected an ACK to our FIN, but got none: %s", err) - } - - windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + seqNumOffset - conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{ - SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), - AckNum: seqNumForTheirFIN, - Flags: testbench.TCPFlags(tt.flags), - }}, tt.payloads...)) - - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { - t.Errorf("expected an ACK but got none: %s", err) - } - }) - } - } -} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index 7cd7ff703..0523887d9 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -108,3 +108,75 @@ func TestTCPOutsideTheWindow(t *testing.T) { }) } } + +// TestAckOTWSeqInClosing tests that the DUT should send an ACK with +// the right ACK number when receiving a packet with OTW Seq number +// in CLOSING state. 
https://tools.ietf.org/html/rfc793#page-69 +func TestAckOTWSeqInClosing(t *testing.T) { + for _, tt := range []struct { + description string + flags header.TCPFlags + payloads testbench.Layers + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 0, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 0, false}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 1, true}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil, 2, true}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("Sample Data")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + dut.Shutdown(t, acceptFD, unix.SHUT_WR) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected FINACK from DUT, but got none: %s", err) + } + + // Do not ack the FIN from DUT so that the TCP 
state on DUT is CLOSING instead of CLOSED. + seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) + conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK to our FIN, but got none: %s", err) + } + + windowSize := seqnum.Size(*gotTCP.WindowSize) + tt.seqNumOffset + conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{ + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), + AckNum: seqNumForTheirFIN, + Flags: testbench.TCPFlags(tt.flags), + }}, tt.payloads...)) + + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if tt.expectACK && err != nil { + t.Errorf("expected an ACK but got none: %s", err) + } + if !tt.expectACK && gotACK != nil { + t.Errorf("expected no ACK but got one: %s", gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go deleted file mode 100644 index a208210ac..000000000 --- a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package tcp_unacc_seq_ack_closing_test - -import ( - "flag" - "fmt" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.Initialize(flag.CommandLine) -} - -func TestSimultaneousCloseUnaccSeqAck(t *testing.T) { - for _, tt := range []struct { - description string - makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP - seqNumOffset seqnum.Size - expectAck bool - }{ - {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true}, - {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, - {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, - {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, - {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, - {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, - } { - t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { - dut := testbench.NewDUT(t) - listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) - defer dut.Close(t, listenFD) - conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close(t) - - conn.Connect(t) - acceptFD, _ := dut.Accept(t, listenFD) - - // Trigger active close. 
- dut.Shutdown(t, acceptFD, unix.SHUT_WR) - - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("expected a FIN: %s", err) - } - // Do not ack the FIN from DUT so that we get to CLOSING. - seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) - conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) - - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { - t.Errorf("expected an ACK to our FIN, but got none: %s", err) - } - - sampleData := []byte("Sample Data") - samplePayload := &testbench.Payload{Bytes: sampleData} - - origSeq := uint32(*conn.LocalSeqNum(t)) - // Send a segment with OTW Seq / unacc ACK. - tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)) - if tt.description == "OTWSeq" { - // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that - // we stay in CLOSING. 
- tcp.AckNum = seqNumForTheirFIN - } - conn.Send(t, tcp, samplePayload) - - got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) - if tt.expectAck && err != nil { - t.Errorf("expected an ack in CLOSING state, but got none: %s", err) - } - if !tt.expectAck && got != nil { - t.Errorf("expected no ack in CLOSING state, but got one: %s", got) - } - }) - } -} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go index ce0a26171..389bfc629 100644 --- a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go @@ -209,3 +209,66 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { }) } } + +func TestSimultaneousCloseUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := 
dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Shutdown(t, acceptFD, unix.SHUT_WR) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN: %s", err) + } + // Do not ack the FIN from DUT so that we get to CLOSING. + seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) + conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Errorf("expected an ACK to our FIN, but got none: %s", err) + } + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + origSeq := uint32(*conn.LocalSeqNum(t)) + // Send a segment with OTW Seq / unacc ACK. + tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)) + if tt.description == "OTWSeq" { + // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that + // we stay in CLOSING. 
+ tcp.AckNum = seqNumForTheirFIN + } + conn.Send(t, tcp, samplePayload) + + got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Errorf("expected an ack in CLOSING state, but got none: %s", err) + } + if !tt.expectAck && got != nil { + t.Errorf("expected no ack in CLOSING state, but got one: %s", got) + } + }) + } +} diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 4509b5e55..043ada583 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1922,7 +1922,9 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/base:core_headers", gtest, + "//test/util:cleanup", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -2162,6 +2164,7 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", gtest, ], @@ -3990,6 +3993,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "@com_google_absl//absl/base:core_headers", gtest, "//test/util:temp_path", "//test/util:test_main", diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 98d5e432d..087262535 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -13,11 +13,14 @@ // limitations under the License. #include <fcntl.h> +#include <sys/mman.h> #include <unistd.h> #include <vector> #include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -121,6 +124,43 @@ TEST_F(ReadTest, ReadWithOpath) { EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallFailsWithErrno(EBADF)); } +// Test that partial writes that hit SIGSEGV are correctly handled and return +// partial write. 
+TEST_F(ReadTest, PartialReadSIGSEGV) { + // Allocate 2 pages and remove permission from the second. + const size_t size = 2 * kPageSize; + void* addr = + mmap(0, size, PROT_WRITE | PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, 0, 0); + ASSERT_NE(addr, MAP_FAILED); + auto cleanup = Cleanup( + [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); }); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(name_.c_str(), O_RDWR, 0666)); + for (size_t i = 0; i < 2; i++) { + EXPECT_THAT(pwrite(fd.get(), addr, size, 0), + SyscallSucceedsWithValue(size)); + } + + void* badAddr = reinterpret_cast<char*>(addr) + kPageSize; + ASSERT_THAT(mprotect(badAddr, kPageSize, PROT_NONE), SyscallSucceeds()); + + // Attempt to read to both pages. Create a non-contiguous iovec pair to + // ensure operation is done in 2 steps. + struct iovec iov[] = { + { + .iov_base = addr, + .iov_len = kPageSize, + }, + { + .iov_base = addr, + .iov_len = size, + }, + }; + EXPECT_THAT(preadv(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0), + SyscallSucceedsWithValue(size)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/setgid.cc b/test/syscalls/linux/setgid.cc index cd030b094..98f8f3dfe 100644 --- a/test/syscalls/linux/setgid.cc +++ b/test/syscalls/linux/setgid.cc @@ -17,6 +17,7 @@ #include <unistd.h> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/fs_util.h" @@ -24,6 +25,11 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +ABSL_FLAG(std::vector<std::string>, groups, std::vector<std::string>({}), + "groups the test can use"); + +constexpr gid_t kNobody = 65534; + namespace gvisor { namespace testing { @@ -46,6 +52,18 @@ PosixErrorOr<Cleanup> Setegid(gid_t egid) { // Returns a pair of groups that the user is a member of. PosixErrorOr<std::pair<gid_t, gid_t>> Groups() { + // Were we explicitly passed GIDs? 
+ std::vector<std::string> flagged_groups = absl::GetFlag(FLAGS_groups); + if (flagged_groups.size() >= 2) { + int group1; + int group2; + if (!absl::SimpleAtoi(flagged_groups[0], &group1) || + !absl::SimpleAtoi(flagged_groups[1], &group2)) { + return PosixError(EINVAL, "failed converting group flags to ints"); + } + return std::pair<gid_t, gid_t>(group1, group2); + } + // See whether the user is a member of at least 2 groups. std::vector<gid_t> groups(64); for (; groups.size() <= NGROUPS_MAX; groups.resize(groups.size() * 2)) { @@ -58,26 +76,47 @@ PosixErrorOr<std::pair<gid_t, gid_t>> Groups() { return PosixError(errno, absl::StrFormat("getgroups(%d, %p)", groups.size(), groups.data())); } - if (ngroups >= 2) { - return std::pair<gid_t, gid_t>(groups[0], groups[1]); + + if (ngroups < 2) { + // There aren't enough groups. + break; + } + + // TODO(b/181878080): Read /proc/sys/fs/overflowgid once it is supported in + // gVisor. + if (groups[0] == kNobody || groups[1] == kNobody) { + // These groups aren't mapped into our user namespace, so we can't use + // them. + break; } - // There aren't enough groups. - break; + return std::pair<gid_t, gid_t>(groups[0], groups[1]); } - // If we're root in the root user namespace, we can set our GID to whatever we - // want. Try that before giving up. - constexpr gid_t kGID1 = 1111; - constexpr gid_t kGID2 = 2222; - auto cleanup1 = Setegid(kGID1); + // If we're running in gVisor and are root in the root user namespace, we can + // set our GID to whatever we want. Try that before giving up. + // + // This won't work in native tests, as despite having CAP_SETGID, the gofer + // process will be sandboxed and unable to change file GIDs. 
+ if (!IsRunningOnGvisor()) { + return PosixError(EPERM, "no valid groups for native testing"); + } + PosixErrorOr<bool> capable = HaveCapability(CAP_SETGID); + if (!capable.ok()) { + return capable.error(); + } + if (!capable.ValueOrDie()) { + return PosixError(EPERM, "missing CAP_SETGID"); + } + gid_t gid = getegid(); + auto cleanup1 = Setegid(gid); if (!cleanup1.ok()) { return cleanup1.error(); } - auto cleanup2 = Setegid(kGID2); + auto cleanup2 = Setegid(kNobody); if (!cleanup2.ok()) { return cleanup2.error(); } - return std::pair<gid_t, gid_t>(kGID1, kGID2); + return std::pair<gid_t, gid_t>(gid, kNobody); } class SetgidDirTest : public ::testing::Test { @@ -85,17 +124,21 @@ class SetgidDirTest : public ::testing::Test { void SetUp() override { original_gid_ = getegid(); - // TODO(b/175325250): Enable when setgid directories are supported. SKIP_IF(IsRunningWithVFS1()); - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETGID))); + // If we can't find two usable groups, we're in an unsupporting environment. + // Skip the test. + PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups(); + SKIP_IF(!groups.ok()); + groups_ = groups.ValueOrDie(); + + auto cleanup = Setegid(groups_.first); temp_dir_ = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); - groups_ = ASSERT_NO_ERRNO_AND_VALUE(Groups()); } void TearDown() override { - ASSERT_THAT(setegid(original_gid_), SyscallSucceeds()); + EXPECT_THAT(setegid(original_gid_), SyscallSucceeds()); } void MkdirAsGid(gid_t gid, const std::string& path, mode_t mode) { @@ -131,7 +174,7 @@ TEST_F(SetgidDirTest, Control) { ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, 0777)); // Set group to G2, create a file in g1owned, and confirm that G2 owns it. 
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second)); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Open(JoinPath(g1owned, "g2owned").c_str(), O_CREAT | O_RDWR, 0777)); struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); @@ -146,7 +189,7 @@ TEST_F(SetgidDirTest, CreateFile) { ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); // Set group to G2, create a file, and confirm that G1 owns it. - ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second)); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); @@ -194,7 +237,7 @@ TEST_F(SetgidDirTest, OldFile) { ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); // Set group to G2, create a file, confirm that G2 owns it. - ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second)); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); @@ -217,7 +260,7 @@ TEST_F(SetgidDirTest, OldDir) { ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); // Set group to G2, create a directory, confirm that G2 owns it. 
- ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(groups_.second)); auto g2created = JoinPath(g1owned, "g2created"); ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); @@ -306,6 +349,10 @@ class FileModeTest : public ::testing::TestWithParam<FileModeTestcase> {}; TEST_P(FileModeTest, WriteToFile) { SKIP_IF(IsRunningWithVFS1()); + PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups(); + SKIP_IF(!groups.ok()); + + auto cleanup = Setegid(groups.ValueOrDie().first); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); auto path = JoinPath(temp_dir.path(), GetParam().name); @@ -329,26 +376,28 @@ TEST_P(FileModeTest, WriteToFile) { TEST_P(FileModeTest, TruncateFile) { SKIP_IF(IsRunningWithVFS1()); + PosixErrorOr<std::pair<gid_t, gid_t>> groups = Groups(); + SKIP_IF(!groups.ok()); + + auto cleanup = Setegid(groups.ValueOrDie().first); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); auto path = JoinPath(temp_dir.path(), GetParam().name); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666)); - ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); - struct stat stats; - ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); - EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode); // Write something to the file, as truncating an empty file is a no-op. constexpr char c = 'M'; ASSERT_THAT(write(fd.get(), &c, sizeof(c)), SyscallSucceedsWithValue(sizeof(c))); + ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); // For security reasons, truncating the file clears the SUID bit, and clears // the SGID bit when the group executable bit is unset (which is not a true // SGID binary). 
ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); + struct stat stats; ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode); } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 54b45b075..597b5bcb1 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -490,7 +490,11 @@ void TestListenWhileConnect(const TestParam& param, TestAddress const& connector = param.connector; constexpr int kBacklog = 2; - constexpr int kClients = kBacklog + 1; + // Linux completes one more connection than the listen backlog argument. + // To ensure that there is at least one client connection that stays in + // connecting state, keep 2 more client connections than the listen backlog. + // gVisor differs in this behavior though, gvisor.dev/issue/3153. + constexpr int kClients = kBacklog + 2; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( @@ -527,7 +531,7 @@ void TestListenWhileConnect(const TestParam& param, for (auto& client : clients) { constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = client.get(), .events = POLLIN, }; @@ -543,6 +547,10 @@ void TestListenWhileConnect(const TestParam& param, ASSERT_THAT(read(client.get(), &c, sizeof(c)), AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(ECONNREFUSED))); + // The last client connection would be in connecting (SYN_SENT) state. 
+ if (client.get() == clients[kClients - 1].get()) { + ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno); + } } } @@ -598,7 +606,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { connector.addr_len); if (ret != 0) { EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - struct pollfd pfd = { + pollfd pfd = { .fd = conn_fd.get(), .events = POLLOUT, }; @@ -623,6 +631,95 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { } } +// Test if the stack completes atmost listen backlog number of client +// connections. It exercises the path of the stack that enqueues completed +// connections to accept queue vs new incoming SYNs. +TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) { + const auto& param = GetParam(); + const TestAddress& listener = param.listener; + const TestAddress& connector = param.connector; + + constexpr int kBacklog = 1; + // Keep the number of client connections more than the listen backlog. + // Linux completes one more connection than the listen backlog argument. + // gVisor differs in this behavior though, gvisor.dev/issue/3153. + int kClients = kBacklog + 2; + if (IsRunningOnGvisor()) { + kClients--; + } + + // Run the following test for few iterations to test race between accept queue + // getting filled with incoming SYNs. 
+ for (int num = 0; num < 10; num++) { + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + std::vector<FileDescriptor> clients; + // Issue multiple non-blocking client connects. + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); + } + + // Now that client connects are issued, wait for the accept queue to get + // filled and ensure no new client connection is completed. + for (int i = 0; i < kClients; i++) { + pollfd pfd = { + .fd = clients[i].get(), + .events = POLLOUT, + }; + if (i < kClients - 1) { + // Poll for client side connection completions with a large timeout. + // We cannot poll on the listener side without calling accept as poll + // stays level triggered with non-zero accept queue length. 
+ // + // Client side poll would not guarantee that the completed connection + // has been enqueued in to the acccept queue, but the fact that the + // listener ACKd the SYN, means that it cannot complete any new incoming + // SYNs when it has already ACKd for > backlog number of SYNs. + ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1)) + << "num=" << num << " i=" << i << " kClients=" << kClients; + ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i; + } else { + // Now that we expect accept queue filled up, ensure that the last + // client connection never completes with a smaller poll timeout. + ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0)) + << "num=" << num << " i=" << i; + } + + ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0)) + << "num=" << num << " i=" << i; + } + clients.clear(); + // We close the listening side and open a new listener. We could instead + // drain the accept queue by calling accept() and reuse the listener, but + // that is racy as the retransmitted SYNs could get ACKd as we make room in + // the accept queue. + ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0)); + } +} + // TCPFinWait2Test creates a pair of connected sockets then closes one end to // trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local // IP/port on a new socket and tries to connect. 
The connect should fail w/ @@ -937,7 +1034,7 @@ void setupTimeWaitClose(const TestAddress* listener, ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); { constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = passive_closefd.get(), .events = POLLIN, }; @@ -948,7 +1045,7 @@ void setupTimeWaitClose(const TestAddress* listener, { constexpr int kTimeout = 10000; constexpr int16_t want_events = POLLHUP; - struct pollfd pfd = { + pollfd pfd = { .fd = active_closefd.get(), .events = want_events, }; @@ -1181,7 +1278,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { // Wait for accept_fd to process the RST. constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = accept_fd.get(), .events = POLLIN, }; @@ -1705,7 +1802,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { SyscallSucceedsWithValue(sizeof(i))); } - struct pollfd pollfds[kThreadCount]; + pollfd pollfds[kThreadCount]; for (int i = 0; i < kThreadCount; i++) { pollfds[i].fd = listener_fds[i].get(); pollfds[i].events = POLLIN; diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 740992d0a..3373ba72b 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -15,6 +15,7 @@ #include <errno.h> #include <fcntl.h> #include <signal.h> +#include <sys/mman.h> #include <sys/resource.h> #include <sys/stat.h> #include <sys/types.h> @@ -23,6 +24,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/base/macros.h" #include "test/util/cleanup.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -256,6 +258,82 @@ TEST_F(WriteTest, PwriteWithOpath) { SyscallFailsWithErrno(EBADF)); } +// Test that partial writes that hit SIGSEGV are correctly handled and return +// partial write. +TEST_F(WriteTest, PartialWriteSIGSEGV) { + // Allocate 2 pages and remove permission from the second. 
+ const size_t size = 2 * kPageSize; + void* addr = mmap(0, size, PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, 0, 0); + ASSERT_NE(addr, MAP_FAILED); + auto cleanup = Cleanup( + [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); }); + + void* badAddr = reinterpret_cast<char*>(addr) + kPageSize; + ASSERT_THAT(mprotect(badAddr, kPageSize, PROT_NONE), SyscallSucceeds()); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_WRONLY)); + + // Attempt to write both pages to the file. Create a non-contiguous iovec pair + // to ensure operation is done in 2 steps. + struct iovec iov[] = { + { + .iov_base = addr, + .iov_len = kPageSize, + }, + { + .iov_base = addr, + .iov_len = size, + }, + }; + // Write should succeed for the first iovec and half of the second (=2 pages). + EXPECT_THAT(pwritev(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0), + SyscallSucceedsWithValue(2 * kPageSize)); +} + +// Test that partial writes that hit SIGBUS are correctly handled and return +// partial write. +TEST_F(WriteTest, PartialWriteSIGBUS) { + SKIP_IF(getenv("GVISOR_GOFER_UNCACHED")); // Can't mmap from uncached files. + + TempPath mapfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd_map = + ASSERT_NO_ERRNO_AND_VALUE(Open(mapfile.path().c_str(), O_RDWR)); + + // Let the first page be read to force a partial write. + ASSERT_THAT(ftruncate(fd_map.get(), kPageSize), SyscallSucceeds()); + + // Map 2 pages, one of which is not allocated in the backing file. Reading + // from it will trigger a SIGBUS. 
+ const size_t size = 2 * kPageSize; + void* addr = + mmap(NULL, size, PROT_READ, MAP_FILE | MAP_PRIVATE, fd_map.get(), 0); + ASSERT_NE(addr, MAP_FAILED); + auto cleanup = Cleanup( + [addr, size] { EXPECT_THAT(munmap(addr, size), SyscallSucceeds()); }); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_WRONLY)); + + // Attempt to write both pages to the file. Create a non-contiguous iovec pair + // to ensure operation is done in 2 steps. + struct iovec iov[] = { + { + .iov_base = addr, + .iov_len = kPageSize, + }, + { + .iov_base = addr, + .iov_len = size, + }, + }; + // Write should succeed for the first iovec and half of the second (=2 pages). + ASSERT_THAT(pwritev(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0), + SyscallSucceedsWithValue(2 * kPageSize)); +} + } // namespace } // namespace testing diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index 44cb33ae4..c2747d94c 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -8,6 +8,7 @@ go_library( "generator.go", "generator_interfaces.go", "generator_interfaces_array_newtype.go", + "generator_interfaces_dynamic.go", "generator_interfaces_primitive_newtype.go", "generator_interfaces_struct.go", "generator_tests.go", diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 634abd1af..39394d2a7 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -126,6 +126,12 @@ func (g *Generator) writeHeader() error { b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") // Emit build tags. + b.emit("// If there are issues with build tag aggregation, see\n") + b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n") + b.emit("// come from the input set of files used to generate this file. 
This input set\n") + b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n") + b.emit("// tools/defs.bzl:calculate_sets().\n\n") + if t := tags.Aggregate(g.inputs); len(t) > 0 { b.emit(strings.Join(t.Lines(), "\n")) b.emit("\n\n") @@ -381,36 +387,29 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { i := newInterfaceGenerator(t.spec, t.recv, fset) + if t.dynamic { + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") + } + // No validation needed, assume the user knows what they are doing. + i.emitMarshallableForDynamicType() + return i + } switch ty := t.spec.Type.(type) { case *ast.StructType: - if t.dynamic { - // Don't validate because this type is dynamically sized and probably - // contains some funky slices which the validation does not allow. 
- i.emitMarshallableForStruct(ty, t.dynamic) - if t.slice != nil { - abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") - } - break - } i.validateStruct(t.spec, ty) - i.emitMarshallableForStruct(ty, t.dynamic) + i.emitMarshallableForStruct(ty) if t.slice != nil { i.emitMarshallableSliceForStruct(ty, t.slice) } case *ast.Ident: i.validatePrimitiveNewtype(ty) - if t.dynamic { - abortAt(fset.Position(t.slice.comment.Slash), "Primitive type marked as '+marshal dynamic', but primitive types can not be dynamic.") - } i.emitMarshallableForPrimitiveNewtype(ty) if t.slice != nil { i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) } case *ast.ArrayType: i.validateArrayNewtype(t.spec.Name, ty) - if t.dynamic { - abortAt(fset.Position(t.slice.comment.Slash), "Marking array types as `dynamic` is currently not supported.") - } // After validate, we can safely call arrayLen. i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) if t.slice != nil { diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go new file mode 100644 index 000000000..b1a8622cd --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go @@ -0,0 +1,96 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package gomarshal + +func (g *interfaceGenerator) emitMarshallableForDynamicType() { + // The user writes their own MarshalBytes, UnmarshalBytes and SizeBytes for + // dynamic types. Generate the rest using these definitions. + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s is dynamic so it might have slice/string headers. Hence, it is not packed.\n", g.typeName()) + g.emit("return false\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + 
g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") + g.emit("// partially unmarshalled struct.\n") + g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("length, err := writer.Write(buf)\n") + g.emit("return int64(length), err\n") + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index f98e41ed7..5f6306b8f 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -69,11 +69,7 @@ func (g *interfaceGenerator) validateStruct(ts 
*ast.TypeSpec, st *ast.StructType }) } -func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool) bool { - if isDynamic { - // Dynamic types are not packed because a slice header might be present. - return false - } +func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { packed := true forEachStructField(st, func(f *ast.Field) { if f.Tag != nil { @@ -89,17 +85,165 @@ func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool) return packed } -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDynamic bool) { - thisPacked := g.isStructPacked(st, isDynamic) +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + thisPacked := g.isStructPacked(st) - // Dynamic types are supposed to manually implement SizeBytes, MarshalBytes - // and UnmarshalBytes. The rest of the methos are autogenerated and depend on - // the implementation of these three. - if !isDynamic { - g.emitSizeBytesForStruct(st) - g.emitMarshalBytesForStruct(st) - g.emitUnmarshalBytesForStruct(st) - } + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(_, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if size, dynamic := g.scalarSize(t); !dynamic { + 
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. 
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") g.emit("//go:nosplit\n") @@ -284,171 +428,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDyn g.emit("}\n\n") } -func (g *interfaceGenerator) emitSizeBytesForStruct(st *ast.StructType) { - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - forEachStructField(st, fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(_, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if size, dynamic := g.scalarSize(t); !dynamic { - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - 
g.emit("\n}\n\n") -} - -func (g *interfaceGenerator) emitMarshalBytesForStruct(st *ast.StructType) { - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) - g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - return - } - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. 
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") -} - -func (g *interfaceGenerator) emitUnmarshalBytesForStruct(st *ast.StructType) { - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We don't have an instance of the dynamic type we can - // reference here (since the version in this struct is - // anonymous). Use a typed nil pointer to call - // SizeBytes() instead. 
- g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) - g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) - return - } - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("src = src[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") -} - func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { - thisPacked := g.isStructPacked(st, false /* isDynamic */) + thisPacked := g.isStructPacked(st) if slice.inner { abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. 
Remove it from this struct declaration.", slice.ident)) diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index cb2d4e6e3..5bceacd32 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -23,7 +23,10 @@ go_test( go_library( name = "test", testonly = 1, - srcs = ["test.go"], + srcs = [ + "dynamic.go", + "test.go", + ], marshal = True, visibility = ["//tools/go_marshal/test:__subpackages__"], deps = [ diff --git a/tools/go_marshal/test/dynamic.go b/tools/go_marshal/test/dynamic.go new file mode 100644 index 000000000..9a812efe9 --- /dev/null +++ b/tools/go_marshal/test/dynamic.go @@ -0,0 +1,83 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import "gvisor.dev/gvisor/pkg/marshal/primitive" + +// Type12Dynamic is a dynamically sized struct which depends on the +// autogenerator to generate some Marshallable methods for it. +// +// +marshal dynamic +type Type12Dynamic struct { + X primitive.Int64 + Y []primitive.Int64 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Type12Dynamic) SizeBytes() int { + return (len(t.Y) * 8) + t.X.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. 
+func (t *Type12Dynamic) MarshalBytes(dst []byte) { + t.X.MarshalBytes(dst) + dst = dst[t.X.SizeBytes():] + for i, x := range t.Y { + x.MarshalBytes(dst[i*8 : (i+1)*8]) + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Type12Dynamic) UnmarshalBytes(src []byte) { + t.X.UnmarshalBytes(src) + if t.Y != nil { + t.Y = t.Y[:0] + } + for i := t.X.SizeBytes(); i < len(src); i += 8 { + var x primitive.Int64 + x.UnmarshalBytes(src[i:]) + t.Y = append(t.Y, x) + } +} + +// Type13Dynamic is a dynamically sized struct which depends on the +// autogenerator to generate some Marshallable methods for it. +// +// It represents a string in memory which is preceded by a uint32 indicating +// the string size. +// +// +marshal dynamic +type Type13Dynamic string + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Type13Dynamic) SizeBytes() int { + return (*primitive.Uint32)(nil).SizeBytes() + len(*t) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Type13Dynamic) MarshalBytes(dst []byte) { + strLen := primitive.Uint32(len(*t)) + strLen.MarshalBytes(dst) + dst = dst[strLen.SizeBytes():] + copy(dst[:strLen], *t) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. 
+func (t *Type13Dynamic) UnmarshalBytes(src []byte) { + var strLen primitive.Uint32 + strLen.UnmarshalBytes(src) + src = src[strLen.SizeBytes():] + *t = Type13Dynamic(src[:strLen]) +} diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go index b0091dc64..733689c79 100644 --- a/tools/go_marshal/test/marshal_test.go +++ b/tools/go_marshal/test/marshal_test.go @@ -515,20 +515,39 @@ func TestLimitedSliceMarshalling(t *testing.T) { } } -func TestDynamicType(t *testing.T) { +func TestDynamicTypeStruct(t *testing.T) { t12 := test.Type12Dynamic{ X: 32, Y: []primitive.Int64{5, 6, 7}, } + var cc mockCopyContext + cc.setLimit(t12.SizeBytes()) - var m marshal.Marshallable - m = &t12 // Ensure that all methods were generated. - b := make([]byte, m.SizeBytes()) - m.MarshalBytes(b) + if _, err := t12.CopyOut(&cc, usermem.Addr(0)); err != nil { + t.Fatalf("cc.CopyOut faile: %v", err) + } - var res test.Type12Dynamic - res.UnmarshalBytes(b) + res := test.Type12Dynamic{ + Y: make([]primitive.Int64, len(t12.Y)), + } + res.CopyIn(&cc, usermem.Addr(0)) if !reflect.DeepEqual(t12, res) { t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %+v, after = %+v", t12, res) } } + +func TestDynamicTypeIdentifier(t *testing.T) { + s := test.Type13Dynamic("go_marshal") + var cc mockCopyContext + cc.setLimit(s.SizeBytes()) + + if _, err := s.CopyOut(&cc, usermem.Addr(0)); err != nil { + t.Fatalf("cc.CopyOut faile: %v", err) + } + + res := test.Type13Dynamic(make([]byte, len(s))) + res.CopyIn(&cc, usermem.Addr(0)) + if res != s { + t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %s, after = %s", s, res) + } +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index b8eb989d9..e7e3ed74a 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -16,8 +16,6 @@ package test import ( - "gvisor.dev/gvisor/pkg/marshal/primitive" - // We're intentionally using a 
package name alias here even though it's not // necessary to test the code generator's ability to handle package aliases. ex "gvisor.dev/gvisor/tools/go_marshal/test/external" @@ -200,36 +198,3 @@ type Type11 struct { ex.External y int64 } - -// Type12Dynamic is a dynamically sized struct which depends on the autogenerator -// to generate some Marshallable methods for it. -// -// +marshal dynamic -type Type12Dynamic struct { - X primitive.Int64 - Y []primitive.Int64 -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *Type12Dynamic) SizeBytes() int { - return (len(t.Y) * 8) + t.X.SizeBytes() -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *Type12Dynamic) MarshalBytes(dst []byte) { - t.X.MarshalBytes(dst) - dst = dst[t.X.SizeBytes():] - for i, x := range t.Y { - x.MarshalBytes(dst[i*8 : (i+1)*8]) - } -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *Type12Dynamic) UnmarshalBytes(src []byte) { - t.X.UnmarshalBytes(src) - for i := t.X.SizeBytes(); i < len(src); i += 8 { - var x primitive.Int64 - x.UnmarshalBytes(src[i:]) - t.Y = append(t.Y, x) - } -} diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index e1de12e25..93022f504 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -403,6 +403,7 @@ func main() { // on this specific behavior, but the ability to specify slots // allows a manual implementation to be order-dependent. if generateSaverLoader { + fmt.Fprintf(outputFile, "// +checklocksignore\n") fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv) scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) @@ -425,6 +426,7 @@ func main() { // // N.B. See the comment above for the save method. 
if generateSaverLoader { + fmt.Fprintf(outputFile, "// +checklocksignore\n") fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) scanFields(x, "", scanFunctions{value: emitLoadValue}) |