diff options
Diffstat (limited to 'pkg')
76 files changed, 2654 insertions, 1600 deletions
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 2235f8968..648cf4b49 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -151,9 +151,16 @@ const ( // Sticky is a mode bit indicating sticky directories. Sticky FileMode = 01000 + // SetGID is the set group ID bit. + SetGID FileMode = 02000 + + // SetUID is the set user ID bit. + SetUID FileMode = 04000 + // permissionsMask is the mask to apply to FileModes for permissions. It - // includes rwx bits for user, group and others, and sticky bit. - permissionsMask FileMode = 01777 + // includes rwx bits for user, group, and others, as well as the sticky + // bit, setuid bit, and setgid bit. + permissionsMask FileMode = 07777 ) // QIDType is the most significant byte of the FileMode word, to be used as the diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD index d1b14efdb..885958456 100644 --- a/pkg/ring0/BUILD +++ b/pkg/ring0/BUILD @@ -80,6 +80,7 @@ go_library( "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/usermem", ], ) diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go index e8ce608ba..b6e2012e8 100644 --- a/pkg/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -17,6 +17,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" ) // Kernel is a global kernel object. @@ -96,7 +97,7 @@ type SwitchOpts struct { // FloatingPointState is a byte pointer where floating point state is // saved and restored. - FloatingPointState arch.FloatingPointData + FloatingPointState *fpu.State // PageTables are the application page tables. PageTables *pagetables.PageTables diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD index 15b93d61c..f421e1687 100644 --- a/pkg/ring0/gen_offsets/BUILD +++ b/pkg/ring0/gen_offsets/BUILD @@ -35,6 +35,7 @@ go_binary( "//pkg/cpuid", "//pkg/ring0/pagetables", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/usermem", ], ) diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index e9e706716..33c259757 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -239,17 +239,17 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Ss = uint64(Udata) // Ditto. // Perform the switch. - swapgs() // GS will be swapped on return. - WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. - WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. - LoadFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy in floating point. + swapgs() // GS will be swapped on return. + WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. + WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. + LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point. if switchOpts.FullRestore { vector = iret(c, regs, uintptr(userCR3)) } else { vector = sysret(c, regs, uintptr(userCR3)) } - SaveFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy out floating point. - WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. + SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. + WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return } diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index c9a120952..7975e5f92 100644 --- a/pkg/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go @@ -62,7 +62,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) - storeEl0Fpstate(&switchOpts.FloatingPointState[0]) + storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer()) if switchOpts.Flush { FlushTlbByASID(uintptr(switchOpts.UserASID)) @@ -82,7 +82,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { fpDisableTrap = CPACREL1() if fpDisableTrap != 0 { - SaveFloatingPoint(&switchOpts.FloatingPointState[0]) + SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) } vector = c.vecCode diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 85278b389..f660f1614 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -9,7 +9,6 @@ go_library( "arch.go", "arch_aarch64.go", "arch_amd64.go", - "arch_amd64.s", "arch_arm64.go", "arch_state_x86.go", "arch_x86.go", @@ -36,8 +35,8 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", + "//pkg/sentry/arch/fpu", "//pkg/sentry/limits", - "//pkg/sync", "//pkg/syserror", "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 3443b9e1b..921151137 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -50,12 +51,6 @@ func (a Arch) String() string { } } -// FloatingPointData is a generic type, and will always be passed as a pointer. -// We rely on the individual arch implementations to meet all the necessary -// requirements. For example, on x86 the region must be 16-byte aligned and 512 -// bytes in size. -type FloatingPointData []byte - // Context provides architecture-dependent information for a specific thread. // // NOTE(b/34169503): Currently we use uintptr here to refer to a generic native @@ -187,7 +182,7 @@ type Context interface { ClearSingleStep() // FloatingPointData will be passed to underlying save routines. - FloatingPointData() FloatingPointData + FloatingPointData() *fpu.State // NewMmapLayout returns a layout for a new MM, where MinAddr for the // returned layout must be no lower than min, and MaxAddr for the returned @@ -221,16 +216,6 @@ type Context interface { // number of bytes read. PtraceSetRegs(src io.Reader) (int, error) - // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the - // floating-point registers represented by this Context to addr in dst and - // returning the number of bytes written. - PtraceGetFPRegs(dst io.Writer) (int, error) - - // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading - // floating-point registers from src into this Context and returning the - // number of bytes read. - PtraceSetFPRegs(src io.Reader) (int, error) - // PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the // register set given by architecture-defined value regset from this // Context to dst and returning the number of bytes written, which must be @@ -365,18 +350,3 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } - -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 6b81e9708..08789f517 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/syserror" ) @@ -40,65 +41,11 @@ type Registers struct { const ( // SyscallWidth is the width of insturctions. SyscallWidth = 4 - - // fpsimdMagic is the magic number which is used in fpsimd_context. - fpsimdMagic = 0x46508001 - - // fpsimdContextSize is the size of fpsimd_context. - fpsimdContextSize = 0x210 ) // ARMTrapFlag is the mask for the trap flag. const ARMTrapFlag = uint64(1) << 21 -// aarch64FPState is aarch64 floating point state. -type aarch64FPState []byte - -// initAarch64FPState sets up initial state. -// -// Related code in Linux kernel: fpsimd_flush_thread(). -// FPCR = FPCR_RM_RN (0x0 << 22). -// -// Currently, aarch64FPState is only a space of 0x210 length for fpstate. -// The fp head is useless in sentry/ptrace/kvm. -// -func initAarch64FPState(data aarch64FPState) { -} - -func newAarch64FPStateSlice() []byte { - return alignedBytes(4096, 16)[:fpsimdContextSize] -} - -// newAarch64FPState returns an initialized floating point state. -// -// The returned state is large enough to store all floating point state -// supported by host, even if the app won't use much of it due to a restricted -// FeatureSet. -func newAarch64FPState() aarch64FPState { - f := aarch64FPState(newAarch64FPStateSlice()) - initAarch64FPState(f) - return f -} - -// fork creates and returns an identical copy of the aarch64 floating point state. -func (f aarch64FPState) fork() aarch64FPState { - n := aarch64FPState(newAarch64FPStateSlice()) - copy(n, f) - return n -} - -// FloatingPointData returns the raw data pointer. -func (f aarch64FPState) FloatingPointData() FloatingPointData { - return ([]byte)(f) -} - -// NewFloatingPointData returns a new floating point data blob. -// -// This is primarily for use in tests. -func NewFloatingPointData() FloatingPointData { - return ([]byte)(newAarch64FPState()) -} - // State contains the common architecture bits for aarch64 (the build tag of this // file ensures it's only built on aarch64). // @@ -108,7 +55,7 @@ type State struct { Regs Registers // Our floating point state. - aarch64FPState `state:"wait"` + fpState fpu.State `state:"wait"` // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -162,10 +109,10 @@ func (s State) Proto() *rpb.Registers { // Fork creates and returns an identical copy of the state. func (s *State) Fork() State { return State{ - Regs: s.Regs, - aarch64FPState: s.aarch64FPState.fork(), - FeatureSet: s.FeatureSet, - OrigR0: s.OrigR0, + Regs: s.Regs, + fpState: s.fpState.Fork(), + FeatureSet: s.FeatureSet, + OrigR0: s.OrigR0, } } @@ -318,10 +265,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context { case ARM64: return &context64{ State{ - aarch64FPState: newAarch64FPState(), - FeatureSet: fs, + fpState: fpu.NewState(), + FeatureSet: fs, }, - []aarch64FPState(nil), + []fpu.State(nil), } } panic(fmt.Sprintf("unknown architecture %v", arch)) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 15d8ddb40..2571be60f 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -105,7 +106,7 @@ const ( // +stateify savable type context64 struct { State - sigFPState []x86FPState // fpstate to be restored on sigreturn. + sigFPState []fpu.State // fpstate to be restored on sigreturn. } // Arch implements Context.Arch. @@ -113,14 +114,18 @@ func (c *context64) Arch() Arch { return AMD64 } -func (c *context64) copySigFPState() []x86FPState { - var sigfps []x86FPState +func (c *context64) copySigFPState() []fpu.State { + var sigfps []fpu.State for _, s := range c.sigFPState { - sigfps = append(sigfps, s.fork()) + sigfps = append(sigfps, s.Fork()) } return sigfps } +func (c *context64) FloatingPointData() *fpu.State { + return &c.State.fpState +} + // Fork returns an exact copy of this context. func (c *context64) Fork() Context { return &context64{ diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 0c61a3ff7..14ad9483b 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -79,7 +80,7 @@ const ( // +stateify savable type context64 struct { State - sigFPState []aarch64FPState // fpstate to be restored on sigreturn. + sigFPState []fpu.State // fpstate to be restored on sigreturn. } // Arch implements Context.Arch. @@ -87,10 +88,10 @@ func (c *context64) Arch() Arch { return ARM64 } -func (c *context64) copySigFPState() []aarch64FPState { - var sigfps []aarch64FPState +func (c *context64) copySigFPState() []fpu.State { + var sigfps []fpu.State for _, s := range c.sigFPState { - sigfps = append(sigfps, s.fork()) + sigfps = append(sigfps, s.Fork()) } return sigfps } @@ -286,3 +287,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64. return nil } + +func (c *context64) FloatingPointData() *fpu.State { + return &c.State.fpState +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 840e53d33..b2b94c304 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -16,59 +16,7 @@ package arch -import ( - "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/usermem" -) - -// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 -// and SSE state, so this is the equivalent XSTATE_BV value. -const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE - // afterLoadFPState is invoked by afterLoad. func (s *State) afterLoadFPState() { - old := s.x86FPState - - // Recreate the slice. This is done to ensure that it is aligned - // appropriately in memory, and large enough to accommodate any new - // state that may be saved by the new CPU. Even if extraneous new state - // is saved, the state we care about is guaranteed to be a subset of - // new state. Later optimizations can use less space when using a - // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has - // more info. - s.x86FPState = newX86FPState() - - // x86FPState always contains all the FP state supported by the host. - // We may have come from a newer machine that supports additional state - // which we cannot restore. - // - // The x86 FP state areas are backwards compatible, so we can simply - // truncate the additional floating point state. - // - // Applications should not depend on the truncated state because it - // should relate only to features that were not exposed in the app - // FeatureSet. However, because we do not *prevent* them from using - // this state, we must verify here that there is no in-use state - // (according to XSTATE_BV) which we do not support. - if len(s.x86FPState) < len(old) { - // What do we support? - supportedBV := fxsaveBV - if fs := cpuid.HostFeatureSet(); fs.UseXsave() { - supportedBV = fs.ValidXCR0Mask() - } - - // What was in use? - savedBV := fxsaveBV - if len(old) >= xstateBVOffset+8 { - savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:]) - } - - // Supported features must be a superset of saved features. - if savedBV&^supportedBV != 0 { - panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV}) - } - } - - // Copy to the new, aligned location. - copy(s.x86FPState, old) + s.fpState.AfterLoad() } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 91edf0703..e8e52d3a8 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -24,10 +24,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Registers represents the CPU registers for this architecture. @@ -111,57 +110,6 @@ var ( X86TrapFlag uint64 = (1 << 8) ) -// x86FPState is x86 floating point state. -type x86FPState []byte - -// initX86FPState (defined in asm files) sets up initial state. -func initX86FPState(data *byte, useXsave bool) - -func newX86FPStateSlice() []byte { - size, align := cpuid.HostFeatureSet().ExtendedStateSize() - capacity := size - // Always use at least 4096 bytes. - // - // For the KVM platform, this state is a fixed 4096 bytes, so make sure - // that the underlying array is at _least_ that size otherwise we will - // corrupt random memory. This is not a pleasant thing to debug. - if capacity < 4096 { - capacity = 4096 - } - return alignedBytes(capacity, align)[:size] -} - -// newX86FPState returns an initialized floating point state. -// -// The returned state is large enough to store all floating point state -// supported by host, even if the app won't use much of it due to a restricted -// FeatureSet. Since they may still be able to see state not advertised by -// CPUID we must ensure it does not contain any sentry state. -func newX86FPState() x86FPState { - f := x86FPState(newX86FPStateSlice()) - initX86FPState(&f.FloatingPointData()[0], cpuid.HostFeatureSet().UseXsave()) - return f -} - -// fork creates and returns an identical copy of the x86 floating point state. -func (f x86FPState) fork() x86FPState { - n := x86FPState(newX86FPStateSlice()) - copy(n, f) - return n -} - -// FloatingPointData returns the raw data pointer. -func (f x86FPState) FloatingPointData() FloatingPointData { - return []byte(f) -} - -// NewFloatingPointData returns a new floating point data blob. -// -// This is primarily for use in tests. -func NewFloatingPointData() FloatingPointData { - return (FloatingPointData)(newX86FPState()) -} - // Proto returns a protobuf representation of the system registers in State. func (s State) Proto() *rpb.Registers { regs := &rpb.AMD64Registers{ @@ -200,7 +148,7 @@ func (s State) Proto() *rpb.Registers { func (s *State) Fork() State { return State{ Regs: s.Regs, - x86FPState: s.x86FPState.fork(), + fpState: s.fpState.Fork(), FeatureSet: s.FeatureSet, } } @@ -393,149 +341,6 @@ func isValidSegmentBase(reg uint64) bool { return reg < uint64(maxAddr64) } -// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type -// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently, -// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area. -const ptraceFPRegsSize = 512 - -// PtraceGetFPRegs implements Context.PtraceGetFPRegs. -func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) { - return dst.Write(s.x86FPState[:ptraceFPRegsSize]) -} - -// PtraceSetFPRegs implements Context.PtraceSetFPRegs. -func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) { - var f [ptraceFPRegsSize]byte - n, err := io.ReadFull(src, f[:]) - if err != nil { - return 0, err - } - // Force reserved bits in MXCSR to 0. This is consistent with Linux. - sanitizeMXCSR(x86FPState(f[:])) - // N.B. this only copies the beginning of the FP state, which - // corresponds to the FXSAVE area. - copy(s.x86FPState, f[:]) - return n, nil -} - -const ( - // mxcsrOffset is the offset in bytes of the MXCSR field from the start of - // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE - // Area") - mxcsrOffset = 24 - - // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the - // start of the FXSAVE area. - mxcsrMaskOffset = 28 -) - -var ( - mxcsrMask uint32 - initMXCSRMask sync.Once -) - -// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR -// generates a general-protection fault (#GP) in response to an attempt to set -// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section -// 10.5.1.2 "SSE State") -func sanitizeMXCSR(f x86FPState) { - mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:]) - initMXCSRMask.Do(func() { - temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16)) - initX86FPState(&temp.FloatingPointData()[0], false /* useXsave */) - mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) - if mxcsrMask == 0 { - // "If the value of the MXCSR_MASK field is 00000000H, then the - // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM - // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR - // Register" - mxcsrMask = 0xffbf - } - }) - mxcsr &= mxcsrMask - usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) -} - -const ( - // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal - // to the size of the XSAVE legacy area (512 bytes) plus the size of the - // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's - // X86_XSTATE_SSE_SIZE. - minXstateBytes = 512 + 64 - - // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD - // field in Linux's struct user_xstateregs, which is the type manipulated - // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently, - // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET. - userXstateXCR0Offset = 464 - - // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86 - // XSAVE area. - xstateBVOffset = 512 - - // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the - // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is - // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE - // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header". - // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP - // exceptions resulting from invalid values; we aren't. Linux also never - // uses the compacted format when doing XSAVE and doesn't even define the - // compaction extensions to XSAVE as a CPU feature, so for simplicity we - // assume no one is using them. - xsaveHeaderZeroedOffset = 512 + 8 - xsaveHeaderZeroedBytes = 64 - 8 -) - -func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) { - // N.B. s.x86FPState may contain more state than the application - // expects. We only copy the subset that would be in their XSAVE area. - ess, _ := s.FeatureSet.ExtendedStateSize() - f := make([]byte, ess) - copy(f, s.x86FPState) - // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are - // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE - // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE - // mask. GDB relies on this: see - // gdb/x86-linux-nat.c:x86_linux_read_description(). - usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask()) - if len(f) > maxlen { - f = f[:maxlen] - } - return dst.Write(f) -} - -func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) { - // Allow users to pass an xstate register set smaller than ours (they can - // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes. - // Also allow users to pass a register set larger than ours; anything after - // their ExtendedStateSize will be ignored. (I think Linux technically - // permits setting a register set smaller than minXstateBytes, but it has - // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().) - if maxlen < minXstateBytes { - return 0, unix.EFAULT - } - ess, _ := s.FeatureSet.ExtendedStateSize() - if maxlen > int(ess) { - maxlen = int(ess) - } - f := make([]byte, maxlen) - if _, err := io.ReadFull(src, f); err != nil { - return 0, err - } - // Force reserved bits in MXCSR to 0. This is consistent with Linux. - sanitizeMXCSR(x86FPState(f)) - // Users can't enable *more* XCR0 bits than what we, and the CPU, support. - xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:]) - xstateBV &= s.FeatureSet.ValidXCR0Mask() - usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) - // Force XCOMP_BV and reserved bytes in the XSAVE header to 0. - reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes] - for i := range reserved { - reserved[i] = 0 - } - return copy(s.x86FPState, f), nil -} - // Register sets defined in include/uapi/linux/elf.h. const ( _NT_PRSTATUS = 1 @@ -552,12 +357,9 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, } return s.PtraceGetRegs(dst) case _NT_PRFPREG: - if maxlen < ptraceFPRegsSize { - return 0, syserror.EFAULT - } - return s.PtraceGetFPRegs(dst) + return s.fpState.PtraceGetFPRegs(dst, maxlen) case _NT_X86_XSTATE: - return s.ptraceGetXstateRegs(dst, maxlen) + return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet) default: return 0, syserror.EINVAL } @@ -572,12 +374,9 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, } return s.PtraceSetRegs(src) case _NT_PRFPREG: - if maxlen < ptraceFPRegsSize { - return 0, syserror.EFAULT - } - return s.PtraceSetFPRegs(src) + return s.fpState.PtraceSetFPRegs(src, maxlen) case _NT_X86_XSTATE: - return s.ptraceSetXstateRegs(src, maxlen) + return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet) default: return 0, syserror.EINVAL } @@ -609,10 +408,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context { case AMD64: return &context64{ State{ - x86FPState: newX86FPState(), + fpState: fpu.NewState(), FeatureSet: fs, }, - []x86FPState(nil), + []fpu.State(nil), } } panic(fmt.Sprintf("unknown architecture %v", arch)) diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go index 0c73fcbfb..5d7b99bd9 100644 --- a/pkg/sentry/arch/arch_x86_impl.go +++ b/pkg/sentry/arch/arch_x86_impl.go @@ -18,6 +18,7 @@ package arch import ( "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" ) // State contains the common architecture bits for X86 (the build tag of this @@ -29,7 +30,7 @@ type State struct { Regs Registers // Our floating point state. - x86FPState `state:"wait"` + fpState fpu.State `state:"wait"` // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD new file mode 100644 index 000000000..0a5395267 --- /dev/null +++ b/pkg/sentry/arch/fpu/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "fpu", + srcs = [ + "fpu.go", + "fpu_amd64.go", + "fpu_amd64.s", + "fpu_arm64.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/cpuid", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go new file mode 100644 index 000000000..867d309a3 --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu.go @@ -0,0 +1,54 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fpu provides basic floating point helpers. +package fpu + +import ( + "fmt" + "reflect" +) + +// State represents floating point state. +// +// This is a simple byte slice, but may have architecture-specific methods +// attached to it. +type State []byte + +// ErrLoadingState indicates a failed restore due to unusable floating point +// state. +type ErrLoadingState struct { + // supported is the supported floating point state. + supportedFeatures uint64 + + // saved is the saved floating point state. + savedFeatures uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrLoadingState) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures) +} + +// alignedBytes returns a slice of size bytes, aligned in memory to the given +// alignment. This is used because we require certain structures to be aligned +// in a specific way (for example, the X86 floating point data). +func alignedBytes(size, alignment uint) []byte { + data := make([]byte, size+alignment-1) + offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment)) + if offset == 0 { + return data[:size:size] + } + return data[alignment-offset:][:size:size] +} diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go new file mode 100644 index 000000000..3a62f51be --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_amd64.go @@ -0,0 +1,280 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 i386 + +package fpu + +import ( + "io" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// initX86FPState (defined in asm files) sets up initial state. +func initX86FPState(data *byte, useXsave bool) + +func newX86FPStateSlice() State { + size, align := cpuid.HostFeatureSet().ExtendedStateSize() + capacity := size + // Always use at least 4096 bytes. + // + // For the KVM platform, this state is a fixed 4096 bytes, so make sure + // that the underlying array is at _least_ that size otherwise we will + // corrupt random memory. This is not a pleasant thing to debug. + if capacity < 4096 { + capacity = 4096 + } + return alignedBytes(capacity, align)[:size] +} + +// NewState returns an initialized floating point state. +// +// The returned state is large enough to store all floating point state +// supported by host, even if the app won't use much of it due to a restricted +// FeatureSet. Since they may still be able to see state not advertised by +// CPUID we must ensure it does not contain any sentry state. +func NewState() State { + f := newX86FPStateSlice() + initX86FPState(&f[0], cpuid.HostFeatureSet().UseXsave()) + return f +} + +// Fork creates and returns an identical copy of the x86 floating point state. +func (s *State) Fork() State { + n := newX86FPStateSlice() + copy(n, *s) + return n +} + +// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type +// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently, +// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area. +const ptraceFPRegsSize = 512 + +// PtraceGetFPRegs implements Context.PtraceGetFPRegs. +func (s *State) PtraceGetFPRegs(dst io.Writer, maxlen int) (int, error) { + if maxlen < ptraceFPRegsSize { + return 0, syserror.EFAULT + } + + return dst.Write((*s)[:ptraceFPRegsSize]) +} + +// PtraceSetFPRegs implements Context.PtraceSetFPRegs. +func (s *State) PtraceSetFPRegs(src io.Reader, maxlen int) (int, error) { + if maxlen < ptraceFPRegsSize { + return 0, syserror.EFAULT + } + + var f [ptraceFPRegsSize]byte + n, err := io.ReadFull(src, f[:]) + if err != nil { + return 0, err + } + // Force reserved bits in MXCSR to 0. This is consistent with Linux. + sanitizeMXCSR(State(f[:])) + // N.B. this only copies the beginning of the FP state, which + // corresponds to the FXSAVE area. + copy(*s, f[:]) + return n, nil +} + +const ( + // mxcsrOffset is the offset in bytes of the MXCSR field from the start of + // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE + // Area") + mxcsrOffset = 24 + + // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the + // start of the FXSAVE area. + mxcsrMaskOffset = 28 +) + +var ( + mxcsrMask uint32 + initMXCSRMask sync.Once +) + +const ( + // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal + // to the size of the XSAVE legacy area (512 bytes) plus the size of the + // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's + // X86_XSTATE_SSE_SIZE. + minXstateBytes = 512 + 64 + + // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD + // field in Linux's struct user_xstateregs, which is the type manipulated + // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently, + // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET. + userXstateXCR0Offset = 464 + + // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86 + // XSAVE area. + xstateBVOffset = 512 + + // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the + // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is + // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE + // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header". + // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP + // exceptions resulting from invalid values; we aren't. Linux also never + // uses the compacted format when doing XSAVE and doesn't even define the + // compaction extensions to XSAVE as a CPU feature, so for simplicity we + // assume no one is using them. + xsaveHeaderZeroedOffset = 512 + 8 + xsaveHeaderZeroedBytes = 64 - 8 +) + +// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR +// generates a general-protection fault (#GP) in response to an attempt to set +// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section +// 10.5.1.2 "SSE State") +func sanitizeMXCSR(f State) { + mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:]) + initMXCSRMask.Do(func() { + temp := State(alignedBytes(uint(ptraceFPRegsSize), 16)) + initX86FPState(&temp[0], false /* useXsave */) + mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:]) + if mxcsrMask == 0 { + // "If the value of the MXCSR_MASK field is 00000000H, then the + // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM + // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR + // Register" + mxcsrMask = 0xffbf + } + }) + mxcsr &= mxcsrMask + usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr) +} + +// PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by +// writing the floating point registers from this state to dst and returning the +// number of bytes written, which must be less than or equal to maxlen. +func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid.FeatureSet) (int, error) { + // N.B. s.x86FPState may contain more state than the application + // expects. We only copy the subset that would be in their XSAVE area. + ess, _ := featureSet.ExtendedStateSize() + f := make([]byte, ess) + copy(f, *s) + // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are + // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE + // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE + // mask. GDB relies on this: see + // gdb/x86-linux-nat.c:x86_linux_read_description(). + usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask()) + if len(f) > maxlen { + f = f[:maxlen] + } + return dst.Write(f) +} + +// PtraceSetXstateRegs implements ptrace(PTRACE_SETREGS, NT_X86_XSTATE) by +// reading floating point registers from src and returning the number of bytes +// read, which must be less than or equal to maxlen. +func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid.FeatureSet) (int, error) { + // Allow users to pass an xstate register set smaller than ours (they can + // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes. + // Also allow users to pass a register set larger than ours; anything after + // their ExtendedStateSize will be ignored. (I think Linux technically + // permits setting a register set smaller than minXstateBytes, but it has + // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().) + if maxlen < minXstateBytes { + return 0, unix.EFAULT + } + ess, _ := featureSet.ExtendedStateSize() + if maxlen > int(ess) { + maxlen = int(ess) + } + f := make([]byte, maxlen) + if _, err := io.ReadFull(src, f); err != nil { + return 0, err + } + // Force reserved bits in MXCSR to 0. This is consistent with Linux. + sanitizeMXCSR(State(f)) + // Users can't enable *more* XCR0 bits than what we, and the CPU, support. + xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:]) + xstateBV &= featureSet.ValidXCR0Mask() + usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV) + // Force XCOMP_BV and reserved bytes in the XSAVE header to 0. + reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes] + for i := range reserved { + reserved[i] = 0 + } + return copy(*s, f), nil +} + +// BytePointer returns a pointer to the first byte of the state. +// +//go:nosplit +func (s *State) BytePointer() *byte { + return &(*s)[0] +} + +// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 +// and SSE state, so this is the equivalent XSTATE_BV value. +const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE + +// AfterLoad converts the loaded state to the format that compatible with the +// current processor. +func (s *State) AfterLoad() { + old := *s + + // Recreate the slice. This is done to ensure that it is aligned + // appropriately in memory, and large enough to accommodate any new + // state that may be saved by the new CPU. Even if extraneous new state + // is saved, the state we care about is guaranteed to be a subset of + // new state. Later optimizations can use less space when using a + // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has + // more info. + *s = NewState() + + // x86FPState always contains all the FP state supported by the host. + // We may have come from a newer machine that supports additional state + // which we cannot restore. + // + // The x86 FP state areas are backwards compatible, so we can simply + // truncate the additional floating point state. + // + // Applications should not depend on the truncated state because it + // should relate only to features that were not exposed in the app + // FeatureSet. However, because we do not *prevent* them from using + // this state, we must verify here that there is no in-use state + // (according to XSTATE_BV) which we do not support. + if len(*s) < len(old) { + // What do we support? + supportedBV := fxsaveBV + if fs := cpuid.HostFeatureSet(); fs.UseXsave() { + supportedBV = fs.ValidXCR0Mask() + } + + // What was in use? + savedBV := fxsaveBV + if len(old) >= xstateBVOffset+8 { + savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:]) + } + + // Supported features must be a superset of saved features. + if savedBV&^supportedBV != 0 { + panic(ErrLoadingState{supportedFeatures: supportedBV, savedFeatures: savedBV}) + } + } + + // Copy to the new, aligned location. + copy(*s, old) +} diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/fpu/fpu_amd64.s index 6c10336e7..6c10336e7 100644 --- a/pkg/sentry/arch/arch_amd64.s +++ b/pkg/sentry/arch/fpu/fpu_amd64.s diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go new file mode 100644 index 000000000..d2f62631d --- /dev/null +++ b/pkg/sentry/arch/fpu/fpu_arm64.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package fpu + +const ( + // fpsimdMagic is the magic number which is used in fpsimd_context. + fpsimdMagic = 0x46508001 + + // fpsimdContextSize is the size of fpsimd_context. + fpsimdContextSize = 0x210 +) + +// initAarch64FPState sets up initial state. +// +// Related code in Linux kernel: fpsimd_flush_thread(). +// FPCR = FPCR_RM_RN (0x0 << 22). +// +// Currently, aarch64FPState is only a space of 0x210 length for fpstate. +// The fp head is useless in sentry/ptrace/kvm. +// +func initAarch64FPState(data *State) { +} + +func newAarch64FPStateSlice() []byte { + return alignedBytes(4096, 16)[:fpsimdContextSize] +} + +// NewState returns an initialized floating point state. +// +// The returned state is large enough to store all floating point state +// supported by host, even if the app won't use much of it due to a restricted +// FeatureSet. +func NewState() State { + f := State(newAarch64FPStateSlice()) + initAarch64FPState(&f) + return f +} + +// Fork creates and returns an identical copy of the aarch64 floating point state. +func (s *State) Fork() State { + n := State(newAarch64FPStateSlice()) + copy(n, *s) + return n +} + +// BytePointer returns a pointer to the first byte of the state. +func (s *State) BytePointer() *byte { + return &(*s)[0] +} diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index e6557cab6..ee3743483 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -98,7 +99,7 @@ func (c *context64) NewSignalStack() NativeSignalStack { const _FP_XSTATE_MAGIC2_SIZE = 4 func (c *context64) fpuFrameSize() (size int, useXsave bool) { - size = len(c.x86FPState) + size = len(c.fpState) if size > 512 { // Make room for the magic cookie at the end of the xsave frame. size += _FP_XSTATE_MAGIC2_SIZE @@ -226,10 +227,10 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Ss = userDS // Save the thread's floating point state. - c.sigFPState = append(c.sigFPState, c.x86FPState) + c.sigFPState = append(c.sigFPState, c.fpState) // Signal handler gets a clean floating point state. - c.x86FPState = newX86FPState() + c.fpState = fpu.NewState() return nil } @@ -273,7 +274,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt // Restore floating point state. l := len(c.sigFPState) if l > 0 { - c.x86FPState = c.sigFPState[l-1] + c.fpState = c.sigFPState[l-1] // NOTE(cl/133042258): State save requires that any slice // elements from '[len:cap]' to be zero value. c.sigFPState[l-1] = nil diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 4491008c2..53281dcba 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -139,9 +140,9 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Regs[30] = uint64(act.Restorer) // Save the thread's floating point state. - c.sigFPState = append(c.sigFPState, c.aarch64FPState) + c.sigFPState = append(c.sigFPState, c.fpState) // Signal handler gets a clean floating point state. - c.aarch64FPState = newAarch64FPState() + c.fpState = fpu.NewState() return nil } @@ -166,7 +167,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt // Restore floating point state. l := len(c.sigFPState) if l > 0 { - c.aarch64FPState = c.sigFPState[l-1] + c.fpState = c.sigFPState[l-1] // NOTE(cl/133042258): State save requires that any slice // elements from '[len:cap]' to be zero value. c.sigFPState[l-1] = nil diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 82eda3e43..0ed7aafa5 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -380,16 +380,17 @@ func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length in return nil } -// WriteOut implements fs.InodeOperations.WriteOut. -func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { +// WriteDirtyPagesAndAttrs will write the dirty pages and attributes to the +// gofer without calling Fsync on the remote file. +func (c *CachingInodeOperations) WriteDirtyPagesAndAttrs(ctx context.Context, inode *fs.Inode) error { c.attrMu.Lock() + defer c.attrMu.Unlock() + c.dataMu.Lock() + defer c.dataMu.Unlock() // Write dirty pages back. - c.dataMu.Lock() err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt) - c.dataMu.Unlock() if err != nil { - c.attrMu.Unlock() return err } @@ -399,12 +400,18 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) // Write out cached attributes. if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil { - c.attrMu.Unlock() return err } c.dirtyAttr = fs.AttrMask{} - c.attrMu.Unlock() + return nil +} + +// WriteOut implements fs.InodeOperations.WriteOut. +func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { + if err := c.WriteDirtyPagesAndAttrs(ctx, inode); err != nil { + return err + } // Fsync the remote file. return c.backingFile.Sync(ctx) diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 06d450ba6..8f5a87120 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -204,20 +204,8 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr return entries, nil } -// maybeSync will call FSync on the file if either the cache policy or file -// flags require it. +// maybeSync will call FSync on the file if the file flags require it. func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error { - if n == 0 { - // Nothing to sync. - return nil - } - - if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { - // Call WriteOut directly, as some "writethrough" filesystems - // do not support sync. - return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) - } - flags := file.Flags() var syncType fs.SyncType switch { @@ -254,6 +242,19 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) } + if n == 0 { + // Nothing written. We are done. + return 0, err + } + + // Write the dirty pages and attributes if cache policy tells us to. + if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { + if werr := f.inodeOperations.cachingInodeOps.WriteDirtyPagesAndAttrs(ctx, file.Dirent.Inode); werr != nil { + // Report no bytes written since the write faild. + return 0, werr + } + } + // We may need to sync the written bytes. if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil { // Sync failed. Report 0 bytes written, since none of them are diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index c34451269..43c3c5a2d 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -783,7 +783,15 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { - if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil { + // If the parent is a setgid directory, use the parent's GID + // rather than the caller's and enable setgid. + kgid := creds.EffectiveKGID + mode := opts.Mode + if atomic.LoadUint32(&parent.mode)&linux.S_ISGID != 0 { + kgid = auth.KGID(atomic.LoadUint32(&parent.gid)) + mode |= linux.S_ISGID + } + if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } @@ -1145,7 +1153,15 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving name := rp.Component() // We only want the access mode for creating the file. createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask - fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) + + // If the parent is a setgid directory, use the parent's GID rather + // than the caller's. + kgid := creds.EffectiveKGID + if atomic.LoadUint32(&d.mode)&linux.S_ISGID != 0 { + kgid = auth.KGID(atomic.LoadUint32(&d.gid)) + } + + fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)) if err != nil { dirfile.close(ctx) return nil, err diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 71569dc65..692da02c1 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1102,10 +1102,26 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs d.metadataMu.Lock() defer d.metadataMu.Unlock() + + // As with Linux, if the UID, GID, or file size is changing, we have to + // clear permission bits. Note that when set, clearSGID causes + // permissions to be updated, but does not modify stat.Mask, as + // modification would cause an extra inotify flag to be set. + clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) || + stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) || + stat.Mask&linux.STATX_SIZE != 0 + if clearSGID { + if stat.Mask&linux.STATX_MODE != 0 { + stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode))) + } else { + stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode))) + } + } + if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0, + Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID, UID: stat.Mask&linux.STATX_UID != 0, GID: stat.Mask&linux.STATX_GID != 0, Size: stat.Mask&linux.STATX_SIZE != 0, @@ -1140,7 +1156,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 { + if stat.Mask&linux.STATX_MODE != 0 || clearSGID { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } if stat.Mask&linux.STATX_UID != 0 { diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 283b220bb..4f1ad0c88 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -266,6 +266,20 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off return 0, offset, err } } + + // As with Linux, writing clears the setuid and setgid bits. + if n > 0 { + oldMode := atomic.LoadUint32(&d.mode) + // If setuid or setgid were set, update d.mode and propagate + // changes to the host. + if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode { + atomic.StoreUint32(&d.mode, newMode) + if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil { + return 0, offset, err + } + } + } + return n, offset + n, nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 84e37f793..46c500427 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -689,13 +689,9 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } return err } - creds := rp.Credentials() + if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, rp.Credentials()), }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)) @@ -750,11 +746,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } creds := rp.Credentials() if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, creds), }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)) @@ -963,14 +955,11 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } + // Change the file's owner to the caller. We can't use upperFD.SetStat() // because it will pick up creds from ctx. if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - }, + Stat: parent.newChildOwnerStat(opts.Mode, creds), }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)) diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 58680bc80..454c20d4f 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -749,6 +749,27 @@ func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { ) } +// newChildOwnerStat returns a Statx for configuring the UID, GID, and mode of +// children. +func (d *dentry) newChildOwnerStat(mode linux.FileMode, creds *auth.Credentials) linux.Statx { + stat := linux.Statx{ + Mask: uint32(linux.STATX_UID | linux.STATX_GID), + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + } + // Set GID and possibly the SGID bit if the parent is an SGID directory. + d.copyMu.RLock() + defer d.copyMu.RUnlock() + if atomic.LoadUint32(&d.mode)&linux.ModeSetGID == linux.ModeSetGID { + stat.GID = atomic.LoadUint32(&d.gid) + if stat.Mode&linux.ModeDirectory == linux.ModeDirectory { + stat.Mode = uint16(mode) | linux.ModeSetGID + stat.Mask |= linux.STATX_MODE + } + } + return stat +} + // fileDescription is embedded by overlay implementations of // vfs.FileDescriptionImpl. // diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 25c785fd4..d791c06db 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -205,6 +205,20 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e if err := wrappedFD.SetStat(ctx, opts); err != nil { return err } + + // Changing owners may clear one or both of the setuid and setgid bits, + // so we may have to update opts before setting d.mode. + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{ + Mask: linux.STATX_MODE, + }) + if err != nil { + return err + } + opts.Stat.Mode = stat.Mode + opts.Stat.Mask |= linux.STATX_MODE + } + d.updateAfterSetStatLocked(&opts) if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) @@ -295,7 +309,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } defer wrappedFD.DecRef(ctx) - return wrappedFD.PWrite(ctx, src, offset, opts) + n, err := wrappedFD.PWrite(ctx, src, offset, opts) + if err != nil { + return n, err + } + return fd.updateSetUserGroupIDs(ctx, wrappedFD, n) } // Write implements vfs.FileDescriptionImpl.Write. @@ -307,7 +325,28 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts if err != nil { return 0, err } - return wrappedFD.Write(ctx, src, opts) + n, err := wrappedFD.Write(ctx, src, opts) + if err != nil { + return n, err + } + return fd.updateSetUserGroupIDs(ctx, wrappedFD, n) +} + +func (fd *regularFileFD) updateSetUserGroupIDs(ctx context.Context, wrappedFD *vfs.FileDescription, written int64) (int64, error) { + // Writing can clear the setuid and/or setgid bits. We only have to + // check this if something was written and one of those bits was set. + dentry := fd.dentry() + if written == 0 || atomic.LoadUint32(&dentry.mode)&(linux.S_ISUID|linux.S_ISGID) == 0 { + return written, nil + } + stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_MODE}) + if err != nil { + return written, err + } + dentry.copyMu.Lock() + defer dentry.copyMu.Unlock() + atomic.StoreUint32(&dentry.mode, uint32(stat.Mode)) + return written, nil } // Seek implements vfs.FileDescriptionImpl.Seek. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index 609ad3941..7aea3dcd8 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -51,14 +51,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro return err case linux.PTRACE_GETFPREGS: - _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{ + s := target.Arch().FloatingPointData() + _, err := target.Arch().FloatingPointData().PtraceGetFPRegs(&usermem.IOReadWriter{ Ctx: t, IO: t.MemoryManager(), Addr: data, Opts: usermem.IOOpts{ AddressSpaceActive: true, }, - }) + }, len(*s)) return err case linux.PTRACE_SETREGS: @@ -73,14 +74,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro return err case linux.PTRACE_SETFPREGS: - _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{ + s := target.Arch().FloatingPointData() + _, err := s.PtraceSetFPRegs(&usermem.IOReadWriter{ Ctx: t, IO: t.MemoryManager(), Addr: data, Opts: usermem.IOOpts{ AddressSpaceActive: true, }, - }) + }, len(*s)) return err default: diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 4f9e781af..03a76eb9b 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", @@ -78,6 +79,7 @@ go_test( "//pkg/ring0", "//pkg/ring0/pagetables", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/time", diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 308696efe..d761bbdee 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -73,7 +73,7 @@ func (c *vCPU) KernelSyscall() { // We only trigger a bluepill entry in the bluepill function, and can // therefore be guaranteed that there is no floating point state to be // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no. + ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. ring0.Halt() ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment. } @@ -92,7 +92,7 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no. + ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no. ring0.Halt() ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment. } @@ -124,5 +124,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // Set the context pointer to the saved floating point state. This is // where the guest data has been serialized, the kernel will restore // from this new pointer value. - context.Fpstate = uint64(uintptrValue(&c.floatingPointState[0])) + context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer())) } diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index c317f1e99..578852c3f 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -92,7 +92,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) context.Fpsimd64.Fpsr = fpsimd.Fpsr context.Fpsimd64.Fpcr = fpsimd.Fpcr context.Fpsimd64.Vregs = fpsimd.Vregs @@ -112,12 +112,12 @@ func (c *vCPU) KernelSyscall() { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(&c.floatingPointState[0]) + ring0.SaveVRegs(c.floatingPointState.BytePointer()) } ring0.Halt() @@ -136,12 +136,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) { fpDisableTrap := ring0.CPACREL1() if fpDisableTrap != 0 { - fpsimd := fpsimdPtr(&c.floatingPointState[0]) + fpsimd := fpsimdPtr(c.floatingPointState.BytePointer()) fpcr := ring0.GetFPCR() fpsr := ring0.GetFPSR() fpsimd.Fpcr = uint32(fpcr) fpsimd.Fpsr = uint32(fpsr) - ring0.SaveVRegs(&c.floatingPointState[0]) + ring0.SaveVRegs(c.floatingPointState.BytePointer()) } ring0.Halt() diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index 76fc594a0..e44e995a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -33,7 +33,7 @@ func TestSegments(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6243b9a04..5bce16dde 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -25,13 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) -var dummyFPState = (*byte)(arch.NewFloatingPointData()) +var dummyFPState = fpu.NewState() type testHarness interface { Errorf(format string, args ...interface{}) @@ -159,7 +160,7 @@ func TestApplicationSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -173,7 +174,7 @@ func TestApplicationSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. @@ -190,7 +191,7 @@ func TestApplicationFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -205,7 +206,7 @@ func TestApplicationFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. @@ -223,7 +224,7 @@ func TestRegistersSyscall(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { continue // Retry. @@ -246,7 +247,7 @@ func TestRegistersFault(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -272,7 +273,7 @@ func TestBounce(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err != platform.ErrContextInterrupt { t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) @@ -287,7 +288,7 @@ func TestBounce(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err != platform.ErrContextInterrupt { @@ -319,7 +320,7 @@ func TestBounceStress(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err != platform.ErrContextInterrupt { t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) @@ -340,7 +341,7 @@ func TestInvalidate(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { continue // Retry. @@ -355,7 +356,7 @@ func TestInvalidate(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, Flush: true, }, &si); err == platform.ErrContextInterrupt { @@ -379,7 +380,7 @@ func TestEmptyAddressSpace(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { return true // Retry. @@ -393,7 +394,7 @@ func TestEmptyAddressSpace(t *testing.T) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, FullRestore: true, }, &si); err == platform.ErrContextInterrupt { @@ -469,7 +470,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { a++ @@ -506,7 +507,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, - FloatingPointState: dummyFPState, + FloatingPointState: &dummyFPState, PageTables: pt, }, &si); err == platform.ErrContextInterrupt { a++ diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 916903881..3af96c7e5 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" @@ -70,7 +71,7 @@ type vCPUArchState struct { // floatingPointState is the floating point state buffer used in guest // to host transitions. See usage in bluepill_amd64.go. - floatingPointState arch.FloatingPointData + floatingPointState fpu.State } const ( @@ -151,7 +152,7 @@ func (c *vCPU) initArchState() error { // This will be saved prior to leaving the guest, and we restore from // this always. We cannot use the pointer in the context alone because // we don't know how large the area there is in reality. - c.floatingPointState = arch.NewFloatingPointData() + c.floatingPointState = fpu.NewState() // Set the time offset to the host native time. return c.setSystemTime() @@ -307,12 +308,12 @@ func loadByte(ptr *byte) byte { // emulate instructions like xsave and xrstor. // //go:nosplit -func prefaultFloatingPointState(data arch.FloatingPointData) { - size := len(data) +func prefaultFloatingPointState(data *fpu.State) { + size := len(*data) for i := 0; i < size; i += usermem.PageSize { - loadByte(&(data)[i]) + loadByte(&(*data)[i]) } - loadByte(&(data)[size-1]) + loadByte(&(*data)[size-1]) } // SwitchToUser unpacks architectural-details. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 3d715e570..2edc9d1b2 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -32,7 +33,7 @@ type vCPUArchState struct { // floatingPointState is the floating point state buffer used in guest // to host transitions. See usage in bluepill_arm64.go. - floatingPointState arch.FloatingPointData + floatingPointState fpu.State } const ( diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 059aa43d0..e7d5f3193 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -150,7 +151,7 @@ func (c *vCPU) initArchState() error { c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) } - c.floatingPointState = arch.NewFloatingPointData() + c.floatingPointState = fpu.NewState() return c.setSystemTime() } diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index fc43cc3c0..47efde6a2 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/arch/fpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 6259350ec..01e73b019 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/usermem" ) @@ -62,9 +63,9 @@ func (t *thread) setRegs(regs *arch.Registers) error { } // getFPRegs gets the floating-point data via the GETREGSET ptrace unix. -func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error { +func (t *thread) getFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error { iovec := unix.Iovec{ - Base: (*byte)(&fpState[0]), + Base: fpState.BytePointer(), Len: fpLen, } _, _, errno := unix.RawSyscall6( @@ -81,9 +82,9 @@ func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsav } // setFPRegs sets the floating-point data via the SETREGSET ptrace unix. -func (t *thread) setFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error { +func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error { iovec := unix.Iovec{ - Base: (*byte)(&fpState[0]), + Base: fpState.BytePointer(), Len: fpLen, } _, _, errno := unix.RawSyscall6( diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 5bd526b73..efec93f73 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -75,17 +75,25 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o // errors, we may consume the error and return only the partial read/write. // // Returns false if error is unknown. -func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, op string) (bool, error) { - switch err { - case nil: +func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr error, op string) (bool, error) { + if errOrig == nil { // Typical successful syscall. return true, nil + } + + // Translate error, if possible, to consolidate errors from other packages + // into a smaller set of errors from syserror package. + translatedErr := errOrig + if errno, ok := syserror.TranslateError(errOrig); ok { + translatedErr = errno + } + switch translatedErr { case io.EOF: // EOF is always consumed. If this is a partial read/write // (result != 0), the application will see that, otherwise // they will see 0. return true, nil - case syserror.ErrExceedsFileSizeLimit: + case syserror.EFBIG: t := kernel.TaskFromContext(ctx) if t == nil { panic("I/O error should only occur from a context associated with a Task") @@ -98,7 +106,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, // Simultaneously send a SIGXFSZ per setrlimit(2). t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) return true, syserror.EFBIG - case syserror.ErrInterrupted: + case syserror.EINTR: // The syscall was interrupted. Return nil if it completed // partially, otherwise return the error code that the syscall // needs (to indicate to the kernel what it should do). @@ -110,10 +118,10 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, if !partialResult { // Typical syscall error. - return true, err + return true, errOrig } - switch err { + switch translatedErr { case syserror.EINTR: // Syscall interrupted, but completed a partial // read/write. Like ErrWouldBlock, since we have a @@ -143,7 +151,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil - case syserror.ErrWouldBlock: + case syserror.EWOULDBLOCK: // Syscall would block, but completed a partial read/write. // This case should only be returned by IssueIO for nonblocking // files. Since we have a partial read/write, we consume @@ -151,7 +159,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, return true, nil } - switch err.(type) { + switch errOrig.(type) { case syserror.SyscallRestartErrno: // Identical to the EINTR case. return true, nil diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 97de17afe..56b621357 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -130,17 +130,15 @@ func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) { // TranslateError translates errors to errnos, it will return false if // the error was not registered. func TranslateError(from error) (unix.Errno, bool) { - err, ok := errorMap[from] - if ok { - return err, ok + if err, ok := errorMap[from]; ok { + return err, true } // Try to unwrap the error if we couldn't match an error // exactly. This might mean that a package has its own // error type. for _, unwrap := range errorUnwrappers { - err, ok := unwrap(from) - if ok { - return err, ok + if err, ok := unwrap(from); ok { + return err, true } } return 0, false diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fc622b246..fef065b05 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1287,6 +1287,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) } + case header.NDPNonceOption: + gotOpt, ok := opt.(header.NDPNonceOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } default: t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 554242f0c..3d1bccd15 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -26,29 +26,33 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -// NDPOptionIdentifier is an NDP option type identifier. -type NDPOptionIdentifier uint8 +// ndpOptionIdentifier is an NDP option type identifier. +type ndpOptionIdentifier uint8 const ( - // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer + // ndpSourceLinkLayerAddressOptionType is the type of the Source Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1 + ndpSourceLinkLayerAddressOptionType ndpOptionIdentifier = 1 - // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer + // ndpTargetLinkLayerAddressOptionType is the type of the Target Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2 + ndpTargetLinkLayerAddressOptionType ndpOptionIdentifier = 2 - // NDPPrefixInformationType is the type of the Prefix Information + // ndpPrefixInformationType is the type of the Prefix Information // option, as per RFC 4861 section 4.6.2. - NDPPrefixInformationType NDPOptionIdentifier = 3 + ndpPrefixInformationType ndpOptionIdentifier = 3 - // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // ndpNonceOptionType is the type of the Nonce option, as per + // RFC 3971 section 5.3.2. + ndpNonceOptionType ndpOptionIdentifier = 14 + + // ndpRecursiveDNSServerOptionType is the type of the Recursive DNS // Server option, as per RFC 8106 section 5.1. - NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 + ndpRecursiveDNSServerOptionType ndpOptionIdentifier = 25 - // NDPDNSSearchListOptionType is the type of the DNS Search List option, + // ndpDNSSearchListOptionType is the type of the DNS Search List option, // as per RFC 8106 section 5.2. - NDPDNSSearchListOptionType = 31 + ndpDNSSearchListOptionType ndpOptionIdentifier = 31 ) const ( @@ -198,7 +202,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { // bytes for the whole option. return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) } - kind := NDPOptionIdentifier(temp) + kind := ndpOptionIdentifier(temp) // Get the Length field. length, err := i.opts.ReadByte() @@ -225,13 +229,16 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { } switch kind { - case NDPSourceLinkLayerAddressOptionType: + case ndpSourceLinkLayerAddressOptionType: return NDPSourceLinkLayerAddressOption(body), false, nil - case NDPTargetLinkLayerAddressOptionType: + case ndpTargetLinkLayerAddressOptionType: return NDPTargetLinkLayerAddressOption(body), false, nil - case NDPPrefixInformationType: + case ndpNonceOptionType: + return NDPNonceOption(body), false, nil + + case ndpPrefixInformationType: // Make sure the length of a Prefix Information option // body is ndpPrefixInformationLength, as per RFC 4861 // section 4.6.2. @@ -241,7 +248,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { return NDPPrefixInformation(body), false, nil - case NDPRecursiveDNSServerOptionType: + case ndpRecursiveDNSServerOptionType: opt := NDPRecursiveDNSServer(body) if err := opt.checkAddresses(); err != nil { return nil, true, err @@ -249,7 +256,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { return opt, false, nil - case NDPDNSSearchListOptionType: + case ndpDNSSearchListOptionType: opt := NDPDNSSearchList(body) if err := opt.checkDomainNames(); err != nil { return nil, true, err @@ -316,7 +323,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { continue } - b[0] = byte(o.Type()) + b[0] = byte(o.kind()) // We know this safe because paddedLength would have returned // 0 if o had an invalid length (> 255 * lengthByteUnits). @@ -341,11 +348,11 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { type NDPOption interface { fmt.Stringer - // Type returns the type of the receiver. - Type() NDPOptionIdentifier + // kind returns the type of the receiver. + kind() ndpOptionIdentifier - // Length returns the length of the body of the receiver, in bytes. - Length() int + // length returns the length of the body of the receiver, in bytes. + length() int // serializeInto serializes the receiver into the provided byte // buffer. @@ -365,7 +372,7 @@ type NDPOption interface { // paddedLength returns the length of o, in bytes, with any padding bytes, if // required. func paddedLength(o NDPOption) int { - l := o.Length() + l := o.length() if l == 0 { return 0 @@ -416,6 +423,37 @@ func (b NDPOptionsSerializer) Length() int { return l } +// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPNonceOption []byte + +// kind implements NDPOption. +func (o NDPNonceOption) kind() ndpOptionIdentifier { + return ndpNonceOptionType +} + +// length implements NDPOption. +func (o NDPNonceOption) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPNonceOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPNonceOption) String() string { + return fmt.Sprintf("%T(%x)", o, []byte(o)) +} + +// Nonce returns the nonce value this option holds. +func (o NDPNonceOption) Nonce() []byte { + return []byte(o) +} + // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option // as defined by RFC 4861 section 4.6.1. // @@ -423,22 +461,22 @@ func (b NDPOptionsSerializer) Length() int { // where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPSourceLinkLayerAddressOption tcpip.LinkAddress -// Type implements NDPOption.Type. -func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier { - return NDPSourceLinkLayerAddressOptionType +// kind implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpSourceLinkLayerAddressOptionType } -// Length implements NDPOption.Length. -func (o NDPSourceLinkLayerAddressOption) Length() int { +// length implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPSourceLinkLayerAddressOption) String() string { return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) } @@ -463,22 +501,22 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { // where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPTargetLinkLayerAddressOption tcpip.LinkAddress -// Type implements NDPOption.Type. -func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier { - return NDPTargetLinkLayerAddressOptionType +// kind implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpTargetLinkLayerAddressOptionType } -// Length implements NDPOption.Length. -func (o NDPTargetLinkLayerAddressOption) Length() int { +// length implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPTargetLinkLayerAddressOption) String() string { return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) } @@ -503,17 +541,17 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { // ndpPrefixInformationLength bytes. type NDPPrefixInformation []byte -// Type implements NDPOption.Type. -func (o NDPPrefixInformation) Type() NDPOptionIdentifier { - return NDPPrefixInformationType +// kind implements NDPOption. +func (o NDPPrefixInformation) kind() ndpOptionIdentifier { + return ndpPrefixInformationType } -// Length implements NDPOption.Length. -func (o NDPPrefixInformation) Length() int { +// length implements NDPOption. +func (o NDPPrefixInformation) length() int { return ndpPrefixInformationLength } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPPrefixInformation) serializeInto(b []byte) int { used := copy(b, o) @@ -529,7 +567,7 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPPrefixInformation) String() string { return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", o, @@ -627,17 +665,17 @@ type NDPRecursiveDNSServer []byte // Type returns the type of an NDP Recursive DNS Server option. // -// Type implements NDPOption.Type. -func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier { - return NDPRecursiveDNSServerOptionType +// kind implements NDPOption. +func (NDPRecursiveDNSServer) kind() ndpOptionIdentifier { + return ndpRecursiveDNSServerOptionType } -// Length implements NDPOption.Length. -func (o NDPRecursiveDNSServer) Length() int { +// length implements NDPOption. +func (o NDPRecursiveDNSServer) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { used := copy(b, o) @@ -649,7 +687,7 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPRecursiveDNSServer) String() string { lt := o.Lifetime() addrs, err := o.Addresses() @@ -722,17 +760,17 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { // RFC 8106 section 5.2. type NDPDNSSearchList []byte -// Type implements NDPOption.Type. -func (o NDPDNSSearchList) Type() NDPOptionIdentifier { - return NDPDNSSearchListOptionType +// kind implements NDPOption. +func (o NDPDNSSearchList) kind() ndpOptionIdentifier { + return ndpDNSSearchListOptionType } -// Length implements NDPOption.Length. -func (o NDPDNSSearchList) Length() int { +// length implements NDPOption. +func (o NDPDNSSearchList) length() int { return len(o) } -// serializeInto implements NDPOption.serializeInto. +// serializeInto implements NDPOption. func (o NDPDNSSearchList) serializeInto(b []byte) int { used := copy(b, o) @@ -744,7 +782,7 @@ func (o NDPDNSSearchList) serializeInto(b []byte) int { return used } -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (o NDPDNSSearchList) String() string { lt := o.Lifetime() domainNames, err := o.DomainNames() diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index dc4591253..d0a1a2492 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,6 +16,7 @@ package header import ( "bytes" + "encoding/binary" "errors" "fmt" "io" @@ -192,90 +193,6 @@ func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a -// NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPSourceLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - sll := next.(NDPSourceLinkLayerAddressOption) - if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - // TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the // Ethernet address from an NDPTargetLinkLayerAddressOption. func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { @@ -311,32 +228,309 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { +func TestOpts(t *testing.T) { + const optionHeaderLen = 2 + + checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpNonceOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType) + } + nonce, ok := opt.(NDPNonceOption) + if !ok { + t.Fatalf("got nonce = %T, want = NDPNonceOption", opt) + } + if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } + } + } + + checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) + } + tll, ok := opt.(NDPTargetLinkLayerAddressOption) + if !ok { + t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt) + } + if got, want := tll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) + } + sll, ok := opt.(NDPSourceLinkLayerAddressOption) + if !ok { + t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt) + } + if got, want := sll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + const validLifetimeSeconds = 16909060 + const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + + expectedRDNSSBytes := [...]byte{ + // Type, Length + 25, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 2, 4, 8, + + // Address + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds) + if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + rdnssBytes := expectedRDNSSBytes + rdnssBytes[1] = 1 + rdnssBytes[2] = 2 + + const searchListPaddingBytes = 3 + const domainName = "abc.abcd.e" + expectedSearchListBytes := [...]byte{ + // Type, Length + 31, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 0, 0, 0, + + // Domain names + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds) + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + searchListBytes := expectedSearchListBytes + searchListBytes[2] = 1 + searchListBytes[3] = 2 + + const prefixLength = 43 + const onLinkFlag = false + const slaacFlag = true + const preferredLifetimeSeconds = 84281096 + const onLinkFlagBit = 7 + const slaacFlagBit = 6 + boolToByte := func(v bool) byte { + if v { + return 1 + } + return 0 + } + flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit + expectedPrefixInformationBytes := [...]byte{ + // Type, Length + 3, 4, + + prefixLength, flags, + + // Valid Lifetime + 1, 2, 3, 4, + + // Preferred Lifetime + 5, 6, 7, 8, + + // Reserved2 + 0, 0, 0, 0, + + // Address + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds) + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds) + if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + prefixInformationBytes := expectedPrefixInformationBytes + prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1 + binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1) tests := []struct { name string buf []byte + opt NDPOption expectedBuf []byte - addr tcpip.LinkAddress + check func(*testing.T, NDPOption) }{ { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", + name: "Nonce", + buf: make([]byte, 8), + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 6}), + }, + { + name: "Nonce with padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 0}), + }, + + { + name: "TLL Ethernet", + buf: make([]byte, 8), + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Empty", + buf: nil, + opt: NDPTargetLinkLayerAddressOption(""), + expectedBuf: nil, }, + { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", + name: "SLL Ethernet", + buf: make([]byte, 8), + opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), }, { - "Empty", - nil, - nil, - "", + name: "SLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "SLL Empty", + buf: nil, + opt: NDPSourceLinkLayerAddressOption(""), + expectedBuf: nil, + }, + + { + name: "RDNSS", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPRecursiveDNSServer holds the option after the header bytes. + opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]), + expectedBuf: expectedRDNSSBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpRecursiveDNSServerOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType) + } + rdnss, ok := opt.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt) + } + if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + if addrs, err := rdnss.Addresses(); err != nil { + t.Errorf("Addresses(): %s", err) + } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) + } + }, + }, + + { + name: "Search list", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]), + expectedBuf: expectedSearchListBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpDNSSearchListOptionType { + t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType) + } + + dnssl, ok := opt.(NDPDNSSearchList) + if !ok { + t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt) + } + if got, want := dnssl.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + + if domainNames, err := dnssl.DomainNames(); err != nil { + t.Errorf("DomainNames(): %s", err) + } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + }, + }, + + { + name: "Prefix Information", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPPrefixInformation holds the option after the header bytes. + opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]), + expectedBuf: expectedPrefixInformationBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.kind(); got != ndpPrefixInformationType { + t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType) + } + + pi, ok := opt.(NDPPrefixInformation) + if !ok { + t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt) + } + + if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want { + t.Errorf("got length() = %d, want = %d", got, want) + } + if got := pi.PrefixLength(); got != prefixLength { + t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength) + } + if got := pi.OnLinkFlag(); got != onLinkFlag { + t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag) + } + if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag { + t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag) + } + if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got ValidLifetime() = %s, want = %s", got, want) + } + if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want { + t.Errorf("got PreferredLifetime() = %s, want = %s", got, want) + } + if got := pi.Prefix(); got != address { + t.Errorf("got Prefix() = %s, want = %s", got, address) + } + }, }, } @@ -344,230 +538,47 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := NDPOptions(test.buf) serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), + test.opt, } if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) + t.Fatalf("got Length() = %d, want = %d", got, want) } opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" { + t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff) } it, err := opts.Iter(true) if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err) } if len(test.expectedBuf) > 0 { next, done, err := it.Next() if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + t.Fatal("got Next() = (_, true, _), want = (_, false, _)") } + test.check(t, next) } // Iterator should not return anything else. next, done, err := it.Next() if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") + t.Error("got Next() = (_, false, _), want = (_, true, _)") } if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) } }) } } -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength = %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, want); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - func TestNDPRecursiveDNSServerOption(t *testing.T) { tests := []struct { name string @@ -635,8 +646,8 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { if done { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + if got := next.kind(); got != ndpRecursiveDNSServerOptionType { + t.Fatalf("got Type = %d, want = %d", got, ndpRecursiveDNSServerOptionType) } opt, ok := next.(NDPRecursiveDNSServer) @@ -1060,86 +1071,6 @@ func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { } } -func TestNDPDNSSearchListOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 31, 3, 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPDNSSearchList(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPDNSSearchListOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) - } - - opt, ok := next.(NDPDNSSearchList) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) - } - if got := opt.Type(); got != 31 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16777216*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - domainNames, err := opt.DomainNames() - if err != nil { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { - t.Errorf("domain names mismatch (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -1472,8 +1403,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType) } // Test the next (Target Link-Layer) option. @@ -1487,8 +1418,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType) } // Test the next (Prefix Information) option. @@ -1503,8 +1434,8 @@ func TestNDPOptionsIter(t *testing.T) { if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + if got := next.kind(); got != ndpPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, ndpPrefixInformationType) } // Iterator should not return anything else. diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go index 6fe9a336b..55ab1d7cf 100644 --- a/pkg/tcpip/header/ndpoptionidentifier_string.go +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. +// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT. package header @@ -22,29 +22,37 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[NDPSourceLinkLayerAddressOptionType-1] - _ = x[NDPTargetLinkLayerAddressOptionType-2] - _ = x[NDPPrefixInformationType-3] - _ = x[NDPRecursiveDNSServerOptionType-25] + _ = x[ndpSourceLinkLayerAddressOptionType-1] + _ = x[ndpTargetLinkLayerAddressOptionType-2] + _ = x[ndpPrefixInformationType-3] + _ = x[ndpNonceOptionType-14] + _ = x[ndpRecursiveDNSServerOptionType-25] + _ = x[ndpDNSSearchListOptionType-31] } const ( - _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" - _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" + _ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType" + _ndpOptionIdentifier_name_1 = "ndpNonceOptionType" + _ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType" + _ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType" ) var ( - _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} + _ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} ) -func (i NDPOptionIdentifier) String() string { +func (i ndpOptionIdentifier) String() string { switch { case 1 <= i && i <= 3: i -= 1 - return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] + return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]] + case i == 14: + return _ndpOptionIdentifier_name_1 case i == 25: - return _NDPOptionIdentifier_name_1 + return _ndpOptionIdentifier_name_2 + case i == 31: + return _ndpOptionIdentifier_name_3 default: - return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ae0461a6d..7ae38d684 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -38,6 +38,7 @@ const ( var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) +var _ ip.DADProtocol = (*endpoint)(nil) // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only @@ -82,7 +83,8 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error { +// SendDADMessage implements ip.DADProtocol. +func (e *endpoint) SendDADMessage(addr tcpip.Address, _ []byte) tcpip.Error { return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress) } @@ -284,9 +286,12 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran e.mu.Lock() e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: e, - NICID: nic.ID(), + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + // ARP does not support sending nonce values. + NonceSize: 0, + Protocol: e, + NICID: nic.ID(), }) e.mu.Unlock() @@ -305,8 +310,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - nicID := e.nic.ID() - stats := e.stats.arp if len(remoteLinkAddr) == 0 { @@ -314,9 +317,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if !ok { - return &tcpip.ErrUnknownNICID{} + addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber) + if err != nil { + return err } if len(addr.Address) == 0 { diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index 0053646ee..eed49f5d2 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -16,14 +16,27 @@ package ip import ( + "bytes" "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +type extendRequest int + +const ( + notRequested extendRequest = iota + requested + extended +) + type dadState struct { + nonce []byte + extendRequest extendRequest + done *bool timer tcpip.Timer @@ -33,14 +46,17 @@ type dadState struct { // DADProtocol is a protocol whose core state machine can be represented by DAD. type DADProtocol interface { // SendDADMessage attempts to send a DAD probe message. - SendDADMessage(tcpip.Address) tcpip.Error + SendDADMessage(tcpip.Address, []byte) tcpip.Error } // DADOptions holds options for DAD. type DADOptions struct { - Clock tcpip.Clock - Protocol DADProtocol - NICID tcpip.NICID + Clock tcpip.Clock + SecureRNG io.Reader + NonceSize uint8 + ExtendDADTransmits uint8 + Protocol DADProtocol + NICID tcpip.NICID } // DAD performs duplicate address detection for addresses. @@ -63,6 +79,10 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic("attempted to initialize DAD state twice") } + if opts.NonceSize != 0 && opts.ExtendDADTransmits == 0 { + panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) + } + *d = DAD{ opts: opts, configs: configs, @@ -96,10 +116,55 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet s = dadState{ done: &done, timer: d.opts.Clock.AfterFunc(0, func() { - var err tcpip.Error dadDone := remaining == 0 + + nonce, earlyReturn := func() ([]byte, bool) { + d.protocolMU.Lock() + defer d.protocolMU.Unlock() + + if done { + return nil, true + } + + s, ok := d.addresses[addr] + if !ok { + panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID)) + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds + // after having sent DupAddrDetectTransmits NS(DAD) messages, the + // interface continues with another MAX_MULTICAST_SOLICIT number of + // NS(DAD) messages transmitted RetransTimer milliseconds apart. + if dadDone && s.extendRequest == requested { + dadDone = false + remaining = d.opts.ExtendDADTransmits + s.extendRequest = extended + } + + if !dadDone && d.opts.NonceSize != 0 { + if s.nonce == nil { + s.nonce = make([]byte, d.opts.NonceSize) + } + + if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil { + panic(fmt.Sprintf("SecureRNG.Read(...): %s", err)) + } else if n != len(s.nonce) { + panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n)) + } + } + + d.addresses[addr] = s + return s.nonce, false + }() + if earlyReturn { + return + } + + var err tcpip.Error if !dadDone { - err = d.opts.Protocol.SendDADMessage(addr) + err = d.opts.Protocol.SendDADMessage(addr, nonce) } d.protocolMU.Lock() @@ -142,6 +207,68 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet return ret } +// ExtendIfNonceEqualLockedDisposition enumerates the possible results from +// ExtendIfNonceEqualLocked. +type ExtendIfNonceEqualLockedDisposition int + +const ( + // Extended indicates that the DAD process was extended. + Extended ExtendIfNonceEqualLockedDisposition = iota + + // AlreadyExtended indicates that the DAD process was already extended. + AlreadyExtended + + // NoDADStateFound indicates that DAD state was not found for the address. + NoDADStateFound + + // NonceDisabled indicates that nonce values are not sent with DAD messages. + NonceDisabled + + // NonceNotEqual indicates that the nonce value passed and the nonce in the + // last send DAD message are not equal. + NonceNotEqual +) + +// ExtendIfNonceEqualLocked extends the DAD process if the provided nonce is the +// same as the nonce sent in the last DAD message. +// +// Precondition: d.protocolMU must be locked. +func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition { + s, ok := d.addresses[addr] + if !ok { + return NoDADStateFound + } + + if d.opts.NonceSize == 0 { + return NonceDisabled + } + + if s.extendRequest != notRequested { + return AlreadyExtended + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds after having + // sent DupAddrDetectTransmits NS(DAD) messages, the interface continues + // with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted + // RetransTimer milliseconds apart. + // + // If a DAD message has already been sent and the nonce value we observed is + // the same as the nonce value we last sent, then we assume our probe was + // looped back and request an extension to the DAD process. + // + // Note, the first DAD message is sent asynchronously so we need to make sure + // that we sent a DAD message by checking if we have a nonce value set. + if s.nonce != nil && bytes.Equal(s.nonce, nonce) { + s.extendRequest = requested + d.addresses[addr] = s + return Extended + } + + return NonceNotEqual +} + // StopLocked stops a currently running DAD process. // // Precondition: d.protocolMU must be locked. diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index e00aa4678..a22b712c6 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "testing" "time" @@ -32,8 +33,8 @@ type mockDADProtocol struct { mu struct { sync.Mutex - dad ip.DAD - sendCount map[tcpip.Address]int + dad ip.DAD + sentNonces map[tcpip.Address][][]byte } } @@ -48,26 +49,30 @@ func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip. } func (m *mockDADProtocol) initLocked() { - m.mu.sendCount = make(map[tcpip.Address]int) + m.mu.sentNonces = make(map[tcpip.Address][][]byte) } -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { m.mu.Lock() defer m.mu.Unlock() - m.mu.sendCount[addr]++ + m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce) return nil } func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendCount := make(map[tcpip.Address]int) + sentNonces := make(map[tcpip.Address][][]byte) for _, a := range addrs { - sendCount[a]++ + sentNonces[a] = append(sentNonces[a], nil) } - diff := cmp.Diff(sendCount, m.mu.sendCount) + return m.checkWithNonce(sentNonces) +} + +func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string { + m.mu.Lock() + defer m.mu.Unlock() + + diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces) m.initLocked() return diff } @@ -84,6 +89,12 @@ func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { m.mu.dad.StopLocked(addr, reason) } +func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce) +} + func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { m.mu.Lock() defer m.mu.Unlock() @@ -277,3 +288,94 @@ func TestDADStop(t *testing.T) { default: } } + +func TestNonce(t *testing.T) { + const ( + nonceSize = 2 + + extendRequestAttempts = 2 + + dupAddrDetectTransmits = 2 + extendTransmits = 5 + ) + + var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte + for i := range secureRNGBytes { + secureRNGBytes[i] = byte(i) + } + + tests := []struct { + name string + mockedReceivedNonce []byte + expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition + expectedTransmits int + }{ + { + name: "not matching", + mockedReceivedNonce: []byte{0, 0}, + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual}, + expectedTransmits: dupAddrDetectTransmits, + }, + { + name: "matching nonce", + mockedReceivedNonce: secureRNGBytes[:nonceSize], + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended}, + expectedTransmits: dupAddrDetectTransmits + extendTransmits, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var dad mockDADProtocol + clock := faketime.NewManualClock() + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: dupAddrDetectTransmits, + RetransmitTimer: time.Second, + } + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) + dad.init(t, dadConfigs, ip.DADOptions{ + Clock: clock, + SecureRNG: &secureRNG, + NonceSize: nonceSize, + ExtendDADTransmits: extendTransmits, + }) + + ch := make(chan dadResult, 1) + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + + clock.Advance(0) + for i, want := range test.expectedResults { + if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { + t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) + } + } + + for i := 0; i < test.expectedTransmits; i++ { + if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{ + addr1: { + secureRNGBytes[nonceSize*i:][:nonceSize], + }, + }); diff != "" { + t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff) + } + + clock.Advance(dadConfigs.RetransmitTimer) + } + + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should not have anymore updates. + select { + case r := <-ch: + t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index aee1652fa..a4edc69c7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 8a2140ebe..a1660e9a3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if !e.protocol.options.AllowExternalLoopbackTraffic { + if header.IsV4LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV4LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats - h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.ip.MalformedPacketsReceived.Increment() - return - } - - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { - stats.ip.MalformedPacketsReceived.Increment() - return - } - srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1114,13 +1097,46 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv4(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false + } + + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + return nil, false } if hasTransportHdr { @@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN type Options struct { // IGMP holds options for IGMP. IGMP IGMPOptions + + // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be accepted. + AllowExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv4 network protocol. @@ -1599,9 +1619,8 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt // TODO(https://gvisor.dev/issue/4586): This will need tweaking when we start // really forwarding packets as we may need to get two addresses, for rx and // tx interfaces. We will also have to take usage into account. - prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) - localAddress := prefixedAddress.Address - if !ok { + localAddress := e.MainAddress().Address + if len(localAddress) == 0 { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6344a3e09..2afa856dc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -369,6 +369,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } + var it header.NDPOptionIterator + { + var err error + it, err = ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the + // packet. + received.invalid.Increment() + return + } + } + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is @@ -382,6 +394,22 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // stack know so it can handle such a scenario and do nothing further with // the NS. if srcAddr == header.IPv6Any { + var nonce []byte + for { + opt, done, err := it.Next() + if err != nil { + received.invalid.Increment() + return + } + if done { + break + } + if n, ok := opt.(header.NDPNonceOption); ok { + nonce = n.Nonce() + break + } + } + // Since this is a DAD message we know the sender does not actually hold // the target address so there is no "holder". var holderLinkAddress tcpip.LinkAddress @@ -397,7 +425,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) @@ -418,21 +446,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - var sourceLinkAddr tcpip.LinkAddress - { - it, err := ns.Options().Iter(false /* check */) - if err != nil { - // Options are not valid as per the wire format, silently drop the - // packet. - received.invalid.Increment() - return - } - - sourceLinkAddr, ok = getSourceLinkAddr(it) - if !ok { - received.invalid.Increment() - return - } + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.invalid.Increment() + return } // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST @@ -586,6 +603,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r e.dad.mu.Unlock() if e.hasTentativeAddr(targetAddr) { + // We only send a nonce value in DAD messages to check for loopedback + // messages so we use the empty nonce value here. + var nonce []byte + // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with @@ -602,7 +623,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: return default: @@ -899,13 +920,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { + // Find an address that we can use as our source address. addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { + addressEndpoint.DecRef() + } else if !e.checkLocalAddress(localAddr) { + // The provided local address is not assigned to us. return &tcpip.ErrBadLocalAddress{} } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d4e63710c..47d713f88 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 46b6cc41a..83e98bab9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress, nonce []byte) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -361,27 +361,48 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t return &tcpip.ErrInvalidEndpointState{} } - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an - // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { - return err - } + switch result := e.mu.ndp.dad.ExtendIfNonceEqualLocked(addr, nonce); result { + case ip.Extended: + // The nonce we got back was the same we sent so we know the message + // indicating a duplicate address was likely ours so do not consider + // the address duplicate here. + return nil + case ip.AlreadyExtended: + // See Extended. + // + // Our DAD message was looped back already. + return nil + case ip.NoDADStateFound: + panic(fmt.Sprintf("expected DAD state for tentative address %s", addr)) + case ip.NonceDisabled: + // If nonce is disabled then we have no way to know if the packet was + // looped-back so we have to assume it indicates a duplicate address. + fallthrough + case ip.NonceNotEqual: + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { + return err + } - prefix := addressEndpoint.Subnet() + prefix := addressEndpoint.Subnet() - switch t := addressEndpoint.ConfigType(); t { - case stack.AddressConfigStatic: - case stack.AddressConfigSlaac: - e.mu.ndp.regenerateSLAACAddr(prefix) - case stack.AddressConfigSlaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil default: - panic(fmt.Sprintf("unrecognized address config type = %d", t)) + panic(fmt.Sprintf("unhandled result = %d", result)) } - - return nil } // transitionForwarding transitions the endpoint's forwarding status to @@ -863,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -904,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if !e.protocol.options.AllowExternalLoopbackTraffic { + if header.IsV6LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV6LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -932,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip - - h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.MalformedPacketsReceived.Increment() - return - } srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1797,16 +1826,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran dispatcher: dispatcher, protocol: p, } + + // NDP options must be 8 octet aligned and the first 2 bytes are used for + // the type and length fields leaving 6 octets as the minimum size for a + // nonce option without padding. + const nonceSize = 6 + + // As per RFC 7527 section 4.1, + // + // If any probe is looped back within RetransTimer milliseconds after + // having sent DupAddrDetectTransmits NS(DAD) messages, the interface + // continues with another MAX_MULTICAST_SOLICIT number of NS(DAD) + // messages transmitted RetransTimer milliseconds apart. + // + // Value taken from RFC 4861 section 10. + const maxMulticastSolicit = 3 + dadOptions := ip.DADOptions{ + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + NonceSize: nonceSize, + ExtendDADTransmits: maxMulticastSolicit, + Protocol: &e.mu.ndp, + NICID: nic.ID(), + } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp.init(e) + e.mu.ndp.init(e, dadOptions) e.mu.mld.init(e) e.dad.mu.Lock() - e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: &e.mu.ndp, - NICID: nic.ID(), - }) + e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, dadOptions) e.dad.mu.Unlock() e.mu.Unlock() @@ -1879,13 +1928,21 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv6(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false } if hasTransportHdr { @@ -1899,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -2013,6 +2070,10 @@ type Options struct { // DADConfigs holds the default DAD configurations used by IPv6 endpoints. DADConfigs stack.DADConfigurations + + // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be accepted. + AllowExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 266a53e3b..81f5f23c3 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -343,6 +343,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { // TestAddIpv6Address tests adding IPv6 addresses. func TestAddIpv6Address(t *testing.T) { + const nicID = 1 + tests := []struct { name string addr tcpip.Address @@ -367,18 +369,18 @@ func TestAddIpv6Address(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err) } - if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { - t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) + if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err) } else if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr) } }) } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 9a425e50a..85a8f9944 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -15,6 +15,7 @@ package ipv6_test import ( + "bytes" "testing" "time" @@ -119,11 +120,26 @@ func TestSendQueuedMLDReports(t *testing.T) { }, } + nonce := [...]byte{ + 1, 2, 3, 4, 5, 6, + } + + const maxNSMessages = 2 + secureRNGBytes := make([]byte, len(nonce)*maxNSMessages) + for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] { + if n := copy(b, nonce[:]); n != len(nonce) { + t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce)) + } + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) clock := faketime.NewManualClock() + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, @@ -154,7 +170,7 @@ func TestSendQueuedMLDReports(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}), )) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d9b728878..536493f87 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1789,18 +1789,14 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer = timer{} } -func (ndp *ndpState) init(ep *endpoint) { +func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { if ndp.defaultRouters != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs - ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{ - Clock: ep.protocol.stack.Clock(), - Protocol: ndp, - NICID: ep.nic.ID(), - }) + ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) @@ -1811,9 +1807,11 @@ func (ndp *ndpState) init(ep *endpoint) { } } -func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */) + return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), header.NDPOptionsSerializer{ + header.NDPNonceOption(nonce), + }) } func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 6e850fd46..52b9a200c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "strings" "testing" @@ -1264,8 +1265,21 @@ func TestCheckDuplicateAddress(t *testing.T) { DupAddrDetectTransmits: 1, RetransmitTimer: time.Second, } + + nonces := [...][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) + } + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - Clock: clock, + SecureRNG: &secureRNG, + Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1278,10 +1292,36 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - dadPacketsSent := 1 + dadPacketsSent := 0 + snmc := header.SolicitedNodeAddr(lladdr0) + remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) + checkDADMsg := func() { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("expected %d-th DAD message", dadPacketsSent) + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber) + } + + if p.Route.RemoteLinkAddress != remoteLinkAddr { + t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), + )) + } if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } + checkDADMsg() // Start DAD for the address we just added. // @@ -1297,6 +1337,7 @@ func TestCheckDuplicateAddress(t *testing.T) { } else if res != stack.DADStarting { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) } + checkDADMsg() // Remove the address and make sure our DAD request was not stopped. if err := s.RemoveAddress(nicID, lladdr0); err != nil { @@ -1328,33 +1369,6 @@ func TestCheckDuplicateAddress(t *testing.T) { default: } - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - - for i := 0; i < dadPacketsSent; i++ { - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", i) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions(nil), - )) - } - // Should have no more packets. if p, ok := e.Read(); ok { t.Errorf("got unexpected packet = %#v", p) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 47796a6ba..43e6d102c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "context" "encoding/binary" "fmt" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" @@ -358,6 +360,66 @@ func TestDADDisabled(t *testing.T) { } } +func TestDADResolveLoopback(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + + dadConfigs := stack.DADConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 1, + } + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + DADConfigs: dadConfigs, + })}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + + // DAD should not resolve after the normal resolution time since our DAD + // message was looped back - we should extend our DAD process. + dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer + clock.Advance(dadResolutionTime) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // Make sure the address does not resolve before the extended resolution time + // has passed. + const delta = time.Nanosecond + // DAD will send extra NS probes if an NS message is looped back. + const extraTransmits = 3 + clock.Advance(dadResolutionTime*extraTransmits - delta) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // DAD should now resolve. + clock.Advance(delta) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -404,6 +466,16 @@ func TestDADResolve(t *testing.T) { }, } + nonces := [][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) + } + for _, test := range tests { test := test @@ -419,7 +491,12 @@ func TestDADResolve(t *testing.T) { headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes) + s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -553,7 +630,7 @@ func TestDADResolve(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}), )) if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 62f7c880e..ca15c0691 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress { return addrs } -// primaryAddress returns the primary address associated with this NIC. -// -// primaryAddress will return the first non-deprecated address if such an -// address exists. If no non-deprecated address exists, the first deprecated -// address will be returned. -func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { +// PrimaryAddress implements NetworkInterface. +func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { ep, ok := n.networkEndpoints[proto] if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{} } - return addressableEndpoint.MainAddress() + return addressableEndpoint.MainAddress(), nil } // removeAddress removes an address from n. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85f0f471a..ff3a385e1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -525,6 +525,14 @@ type NetworkInterface interface { // assigned to it. Spoofing() bool + // PrimaryAddress returns the primary address associated with the interface. + // + // PrimaryAddress will return the first non-deprecated address if such an + // address exists. If no non-deprecated addresses exist, the first deprecated + // address will be returned. If no deprecated addresses exist, the zero value + // will be returned. + PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) + // CheckLocalAddress returns true if the address exists on the interface. CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 53370c354..931a97ddc 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -23,6 +23,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io" mathrand "math/rand" "sync/atomic" "time" @@ -445,6 +446,9 @@ type Stack struct { // used when a random number is required. randomGenerator *mathrand.Rand + // secureRNG is a cryptographically secure random number generator. + secureRNG io.Reader + // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. sendBufferSize tcpip.SendBufferSizeOption @@ -528,6 +532,9 @@ type Options struct { // IPTables are the initial iptables rules. If nil, iptables will allow // all traffic. IPTables *IPTables + + // SecureRNG is a cryptographically secure random number generator. + SecureRNG io.Reader } // TransportEndpointInfo holds useful information about a transport endpoint @@ -636,6 +643,10 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() + if opts.SecureRNG == nil { + opts.SecureRNG = rand.Reader + } + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -652,6 +663,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -1211,20 +1223,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { } // GetMainNICAddress returns the first non-deprecated primary address and prefix -// for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns false if -// the NIC doesn't exist and an empty value if the NIC doesn't have a primary -// address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { +// for the given NIC and protocol. If no non-deprecated primary addresses exist, +// a deprecated address will be returned. If no deprecated addresses exist, the +// zero value will be returned. +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, false + return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownNICID{} } - return nic.primaryAddress(protocol), true + return nic.PrimaryAddress(protocol) } func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -2047,6 +2058,12 @@ func (s *Stack) Rand() *mathrand.Rand { return s.randomGenerator } +// SecureRNG returns the stack's cryptographically secure random number +// generator. +func (s *Stack) SecureRNG() io.Reader { + return s.secureRNG +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 880219007..7ddf7a083 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -62,10 +62,10 @@ const ( ) func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { - if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + if addr, err := s.GetMainNICAddress(nicID, proto); err != nil { + return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err) } else if addr != want { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want) } return nil } @@ -1854,6 +1854,8 @@ func TestNetworkOption(t *testing.T) { } func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { + const nicID = 1 + for _, addrLen := range []int{4, 16} { t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { for canBe := 0; canBe < 3; canBe++ { @@ -1864,8 +1866,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Insert <canBe> primary and <never> never-primary addresses. // Each one will add a network endpoint to the NIC. @@ -1888,34 +1890,34 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { PrefixLen: addrLen * 8, }, } - if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { - t.Fatal("AddProtocolAddressWithOptions failed:", err) + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) } } } // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if len(primaryAddrAdded) == 0 { // No primary addresses present. if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1926,6 +1928,45 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } } +func TestGetMainNICAddressErrors(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // Sanity check with a successful call. + if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil { + t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want) + } + + const unknownNICID = nicID + 1 + switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) { + case *tcpip.ErrUnknownNICID: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err) + } + + // ARP is not an addressable network endpoint. + switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) { + case *tcpip.ErrNotSupported: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err) + } + + const unknownProtocolNumber = 1234 + switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) { + case *tcpip.ErrUnknownProtocol: + default: + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err) + } +} + func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, @@ -2507,11 +2548,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { t.Fatal(err) } - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + + // Disabling the NIC should remove the auto-generated address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } }) @@ -2617,6 +2662,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected // when an address's kind gets "promoted" to permanent from permanentExpired. func TestNewPEBOnPromotionToPermanent(t *testing.T) { + const nicID = 1 + pebs := []stack.PrimaryEndpointBehavior{ stack.NeverPrimaryEndpoint, stack.CanBePrimaryEndpoint, @@ -2630,8 +2677,8 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) + if err := s.CreateNIC(nicID, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Add a permanent address with initial @@ -2639,20 +2686,21 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // NeverPrimaryEndpoint, the address should not // be returned by a call to GetMainNICAddress; // else, it should. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + const address1 = tcpip.Address("\x01") + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) } - addr, ok := s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) } - } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) + } else if addr.Address != address1 { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) } { @@ -2670,13 +2718,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // new peb is respected when an address gets // "promoted" to permanent from a // permanentExpired kind. - r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + const address2 = tcpip.Address("\x02") + r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false) if err != nil { - t.Fatalf("FindRoute failed: %v", err) + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err) } defer r.Release() - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + if err := s.RemoveAddress(nicID, address1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err) } // @@ -2687,19 +2736,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) + const address3 = tcpip.Address("\x03") + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) } var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[1].ProtocolAddresses { + for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) } var expectedList []tcpip.Address @@ -2728,20 +2778,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // should be returned by a call to // GetMainNICAddress; else, our original address // should be returned. - if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) + if err := s.RemoveAddress(nicID, address3); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err) } - addr, ok = s.GetMainNICAddress(1, fakeNetNumber) - if !ok { - t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) + addr, err = s.GetMainNICAddress(nicID, fakeNetNumber) + if err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) } } else { - if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) + if addr.Address != address1 { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) } } }) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 10cbbe589..c1c6cbccd 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -33,8 +33,8 @@ import ( ) const ( - testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 58aabe547..3cc8c36f1 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -72,11 +72,13 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 80afc2825..6462e9d42 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -24,11 +24,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { }) } } + +func TestExternalLoopbackTraffic(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + numPackets = 1 + ) + + loopbackSourcedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + } + + loopbackSourcedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + } + + loopbackDestinedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + } + + loopbackDestinedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + } + + invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidSourceAddressesReceived + } + + invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidDestinationAddressesReceived + } + + tests := []struct { + name string + allowExternalLoopback bool + forwarding bool + rxICMP func(*channel.Endpoint) + invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter + shouldAccept bool + }{ + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", + allowExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", + allowExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: test.allowExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: test.allowExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + } + + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + } + + if test.forwarding { + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: ipv4Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + tcpip.Route{ + Destination: header.IPv6Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + }) + + stats := s.Stats().IP + invalidAddressStat := test.invalidAddressStat(stats) + deliveredPacketsStat := stats.PacketsDelivered + if got := invalidAddressStat.Value(); got != 0 { + t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) + } + if got := deliveredPacketsStat.Value(); got != 0 { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) + } + test.rxICMP(e) + var expectedInvalidPackets uint64 + if !test.shouldAccept { + expectedInvalidPackets = numPackets + } + if got := invalidAddressStat.Value(); got != expectedInvalidPackets { + t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) + } + if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 29266a4fc..77f4a88ec 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -45,82 +45,61 @@ const ( func TestPingMulticastBroadcast(t *testing.T) { const nicID = 1 - rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: utils.RemoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: utils.RemoteIPv6Addr, - Dst: dst, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: utils.RemoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - dstAddr tcpip.Address + name string + protoNum tcpip.NetworkProtocolNumber + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + srcAddr tcpip.Address + dstAddr tcpip.Address + expectedSrc tcpip.Address }{ { - name: "IPv4 unicast", - dstAddr: utils.Ipv4Addr.Address, + name: "IPv4 unicast", + protoNum: header.IPv4ProtocolNumber, + dstAddr: utils.Ipv4Addr.Address, + srcAddr: utils.RemoteIPv4Addr, + rxICMP: utils.RxICMPv4EchoRequest, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 directed broadcast", - dstAddr: utils.Ipv4SubnetBcast, + name: "IPv4 directed broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4SubnetBcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 broadcast", - dstAddr: header.IPv4Broadcast, + name: "IPv4 broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4Broadcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 all-systems multicast", - dstAddr: header.IPv4AllSystems, + name: "IPv4 all-systems multicast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4AllSystems, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv6 unicast", - dstAddr: utils.Ipv6Addr.Address, + name: "IPv6 unicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + expectedSrc: utils.Ipv6Addr.Address, }, { - name: "IPv6 all-nodes multicast", - dstAddr: header.IPv6AllNodesMulticastAddress, + name: "IPv6 all-nodes multicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: header.IPv6AllNodesMulticastAddress, + expectedSrc: utils.Ipv6Addr.Address, }, } @@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - var rxICMP func(*channel.Endpoint, tcpip.Address) - var expectedSrc tcpip.Address - var expectedDst tcpip.Address - var protoNum tcpip.NetworkProtocolNumber - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - rxICMP = rxIPv4ICMP - expectedSrc = utils.Ipv4Addr.Address - expectedDst = utils.RemoteIPv4Addr - protoNum = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - rxICMP = rxIPv6ICMP - expectedSrc = utils.Ipv6Addr.Address - expectedDst = utils.RemoteIPv6Addr - protoNum = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - rxICMP(e, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") } - if pkt.Route.LocalAddress != expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + if pkt.Route.LocalAddress != test.expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) } - if pkt.Route.RemoteAddress != expectedDst { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if pkt.Route.RemoteAddress != test.srcAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if src != test.expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) } - if dst != expectedDst { - t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if dst != test.srcAddr { + t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 4455f6dd7..ed499179f 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -16,6 +16,7 @@ package route_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } + for _, allowExternalLoopback := range []bool{true, false} { + t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + AllowExternalLoopbackTraffic: allowExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + AllowExternalLoopbackTraffic: allowExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - { - err := ep.Connect(connAddr) - if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) - } - } + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() - if test.expectedConnectErr != nil { - return - } + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } - payload := test.icmpBuf(t) - var r bytes.Reader - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) - } + if test.expectedConnectErr != nil { + return + } - // Wait for the endpoint to become readable. - <-ch + var r bytes.Reader + payload := test.icmpBuf(t) + r.Reset(payload) + var wOpts tcpip.WriteOptions + if n, err := ep.Write(&r, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } + // Wait for the endpoint to become readable. + <-ch - test.checkLinkEndpoint(t, e) + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if rr.RemoteAddr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } }) } } diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD index 433004148..a9699a367 100644 --- a/pkg/tcpip/tests/utils/BUILD +++ b/pkg/tcpip/tests/utils/BUILD @@ -8,12 +8,15 @@ go_library( visibility = ["//pkg/tcpip/tests:__subpackages__"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", ], ) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index f414a2234..d1c9f3a94 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -20,13 +20,16 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) // Common NIC IDs used by tests. @@ -45,6 +48,10 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) +const ( + ttl = 255 +) + // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }, }) } + +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on +// the provided endpoint. +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(header.ICMPv4UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(header.ICMPv6UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3b574837c..0a2f3291c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -20,6 +20,7 @@ import ( "fmt" "hash" "io" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" @@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. -func (e *endpoint) deliverAccepted(n *endpoint) { +func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() @@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } select { case e.acceptedChan <- n: + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } e.acceptMu.Unlock() e.waiterQueue.Notify(waiter.EventIn) return @@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.synRcvdCount-- + atomic.AddInt32(&e.synRcvdCount, -1) return err } @@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() + atomic.AddInt32(&e.synRcvdCount, -1) return } ctx.cleanupCompletedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep) + e.deliverAccepted(h.ep, false /*withSynCookie*/) }() // S/R-SAFE: synRcvdCount is the barrier. return nil @@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) incSynRcvdCount() bool { e.acceptMu.Lock() - canInc := e.synRcvdCount < cap(e.acceptedChan) + canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) e.acceptMu.Unlock() if canInc { - e.synRcvdCount++ + atomic.AddInt32(&e.synRcvdCount, 1) } return canInc } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) e.acceptMu.Unlock() return full } @@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n) + go e.deliverAccepted(n, true /*withSynCookie*/) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 129f36d11..43d344350 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -532,8 +532,8 @@ type endpoint struct { segmentQueue segmentQueue `state:"wait"` // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state. - synRcvdCount int + // in SYN-RCVD state; this is only accessed atomically. + synRcvdCount int32 // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index a5c82b8fa..bc6793fc6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -260,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fd499a47b..6c86ae1ae 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -165,7 +165,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) } @@ -178,7 +178,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } @@ -239,7 +239,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) @@ -269,7 +269,7 @@ func TestTCPResetsSentIncrement(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -318,7 +318,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) { // Send a SYN request for a closed port. This should elicit an RST // but NOT an ICMPv4 DstUnreachable packet. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -362,7 +362,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,7 +459,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) @@ -483,7 +483,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) @@ -506,14 +506,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -537,18 +537,19 @@ func TestConnectResetAfterClose(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil // Close the endpoint, make sure we get a FIN segment, then acknowledge // to complete closure of sender, but don't send our own FIN. ep.Close() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -556,7 +557,7 @@ func TestConnectResetAfterClose(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -570,7 +571,7 @@ func TestConnectResetAfterClose(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -612,7 +613,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -625,12 +626,12 @@ func TestCurrentConnectedIncrement(t *testing.T) { } ep.Close() - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -638,7 +639,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -655,7 +656,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -666,7 +667,7 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -690,7 +691,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -699,11 +700,12 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { } // Send a FIN for ESTABLISHED --> CLOSED-WAIT + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -713,7 +715,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -734,7 +736,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -753,7 +755,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -764,7 +766,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 792, + SeqNum: iss.Add(2), AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -804,7 +806,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -813,11 +815,12 @@ func TestSimpleReceive(t *testing.T) { ept := endpointTester{c.EP} data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -840,7 +843,7 @@ func TestSimpleReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1366,12 +1369,13 @@ func TestTOSV4(t *testing.T) { // Check that data is received. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), // Acknum is initial sequence number + 1 + checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), @@ -1414,12 +1418,13 @@ func TestTrafficClassV6(t *testing.T) { // Check that data is received. b := c.GetV6Packet() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv6(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), @@ -1472,7 +1477,7 @@ func TestConnectBindToDevice(t *testing.T) { tcpHdr := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), @@ -1537,7 +1542,7 @@ func TestSynSent(t *testing.T) { if test.reset { // Send a packet with a proper ACK and a RST flag to cause the socket // to error and close out. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) rcvWnd := seqnum.Size(30000) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), @@ -1582,7 +1587,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1593,11 +1598,12 @@ func TestOutOfOrderReceive(t *testing.T) { // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: iss.Add(3), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1607,7 +1613,7 @@ func TestOutOfOrderReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1621,7 +1627,7 @@ func TestOutOfOrderReceive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1639,7 +1645,7 @@ func TestOutOfOrderReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1650,19 +1656,20 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() rcvBufSz := math.MaxUint16 - c.CreateConnected(789, 30000, rcvBufSz) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) ept := endpointTester{c.EP} ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < 100; i++ { c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 796, + SeqNum: iss.Add(6), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1671,19 +1678,19 @@ func TestOutOfOrderFlood(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) } - // Send packet with seqnum 793. It must be discarded because the + // Send packet with seqnum as initial + 3. It must be discarded because the // out-of-order buffer was filled by the previous packets. c.SendPacket(data[3:], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: iss.Add(3), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1692,27 +1699,27 @@ func TestOutOfOrderFlood(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) - // Now send the expected packet, seqnum 790. + // Now send the expected packet with initial sequence number. c.SendPacket(data[:3], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - // Check that only packet 790 is acknowledged. + // Check that only packet with initial sequence number is acknowledged. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(793), + checker.TCPAckNum(uint32(iss)+3), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1722,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1732,11 +1739,12 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1753,7 +1761,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1780,7 +1788,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), + SeqNum: iss.Add(seqnum.Size(len(data))), AckNum: c.IRS.Add(seqnum.Size(2)), RcvWnd: 30000, }) @@ -1790,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1800,11 +1808,12 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1821,7 +1830,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1866,7 +1875,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), + SeqNum: iss.Add(seqnum.Size(len(data))), AckNum: c.IRS.Add(seqnum.Size(2)), RcvWnd: 30000, }) @@ -1876,7 +1885,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) @@ -1897,7 +1906,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.Cleanup() const rcvBufSz = 10 - c.CreateConnected(789, 30000, rcvBufSz) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1913,11 +1922,12 @@ func TestFullWindowReceive(t *testing.T) { for i := range data { data[i] = byte(i % 255) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1934,7 +1944,7 @@ func TestFullWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(0), ), @@ -1956,7 +1966,7 @@ func TestFullWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(10), ), @@ -1996,8 +2006,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { } payload := generateRandomPayload(t, payloadSize) payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(789) - seqNum := iss.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Send payload to the endpoint and return the advertised receive window // from the endpoint. @@ -2005,12 +2014,12 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), Flags: header.TCPFlagAck, RcvWnd: 30000, }) - seqNum = seqNum.Add(payloadLen) + iss = iss.Add(payloadLen) pkt := c.GetPacket() return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale @@ -2054,9 +2063,8 @@ func TestNoWindowShrinking(t *testing.T) { // the right edge of the window does not shrink. // NOTE: Netstack doubles the value specified here. rcvBufSize := 65536 - iss := seqnum.Value(789) // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2069,13 +2077,13 @@ func TestNoWindowShrinking(t *testing.T) { // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. - seqNum := iss.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) payload := []byte{1} c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2092,20 +2100,20 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data: %v, want: %v", got, want) } - seqNum = seqNum.Add(1) // Verify that the ACK does not shrink the window. pkt := c.GetPacket() + iss = iss.Add(1) checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) // Stash the initial window. initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) @@ -2117,11 +2125,11 @@ func TestNoWindowShrinking(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) // Verify that the ACK does not shrink the window. pkt = c.GetPacket() @@ -2129,12 +2137,12 @@ func TestNoWindowShrinking(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } @@ -2145,17 +2153,17 @@ func TestNoWindowShrinking(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqNum, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(0), ), @@ -2173,7 +2181,7 @@ func TestNoWindowShrinking(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(seqNum)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), @@ -2184,7 +2192,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} var r bytes.Reader @@ -2195,12 +2203,13 @@ func TestSimpleSend(t *testing.T) { // Check that data is received. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2214,7 +2223,7 @@ func TestSimpleSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), RcvWnd: 30000, }) @@ -2224,7 +2233,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} var r bytes.Reader @@ -2235,12 +2244,13 @@ func TestZeroWindowSend(t *testing.T) { // Check if we got a zero-window probe. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2250,7 +2260,7 @@ func TestZeroWindowSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2262,7 +2272,7 @@ func TestZeroWindowSend(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2276,7 +2286,7 @@ func TestZeroWindowSend(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), RcvWnd: 30000, }) @@ -2289,7 +2299,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2303,12 +2313,13 @@ func TestScaledWindowConnect(t *testing.T) { // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2322,7 +2333,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(789, 30000, 65535*3) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) data := []byte{1, 2, 3} var r bytes.Reader @@ -2334,12 +2345,13 @@ func TestNonScaledWindowConnect(t *testing.T) { // Check that data is received, and that advertised window is 0xffff, // that is, that it's not scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2407,12 +2419,13 @@ func TestScaledWindowAccept(t *testing.T) { // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2480,12 +2493,13 @@ func TestNonScaledWindowAccept(t *testing.T) { // Check that data is received, and that advertised window is 0xffff, // that is, that it's not scaled. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), @@ -2502,7 +2516,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the buffer size such that a window scale of 5 will be used. const bufSz = 65535 * 10 const ws = uint32(5) - c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -2510,13 +2524,14 @@ func TestZeroScaledWindowReceive(t *testing.T) { remain := 0 sent := 0 data := make([]byte, 50000) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Keep writing till the window drops below len(data). for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2527,7 +2542,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2545,7 +2560,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -2556,7 +2571,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2594,7 +2609,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), @@ -2632,7 +2647,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent @@ -2657,6 +2672,7 @@ func TestSegmentMerging(t *testing.T) { } // Check that we get tcp.InitialCwnd packets. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < tcp.InitialCwnd; i++ { b := c.GetPacket() checker.IPv4(t, b, @@ -2664,7 +2680,7 @@ func TestSegmentMerging(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2675,7 +2691,7 @@ func TestSegmentMerging(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. RcvWnd: 30000, }) @@ -2687,7 +2703,7 @@ func TestSegmentMerging(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2701,7 +2717,7 @@ func TestSegmentMerging(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) @@ -2713,7 +2729,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) c.EP.SocketOptions().SetDelayOption(true) @@ -2728,6 +2744,7 @@ func TestDelay(t *testing.T) { } seq := c.IRS.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for _, want := range [][]byte{allData[:1], allData[1:]} { // Check that data is received. b := c.GetPacket() @@ -2736,7 +2753,7 @@ func TestDelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2751,7 +2768,7 @@ func TestDelay(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2762,7 +2779,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) c.EP.SocketOptions().SetDelayOption(true) @@ -2776,7 +2793,7 @@ func TestUndelay(t *testing.T) { } seq := c.IRS.Add(1) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Check that data is received. first := c.GetPacket() checker.IPv4(t, first, @@ -2784,7 +2801,7 @@ func TestUndelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2807,7 +2824,7 @@ func TestUndelay(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2823,7 +2840,7 @@ func TestUndelay(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2845,7 +2862,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -2861,7 +2878,7 @@ func TestMSSNotDelayed(t *testing.T) { } seq := c.IRS.Add(1) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i, data := range allData { // Check that data is received. packet := c.GetPacket() @@ -2870,7 +2887,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2887,7 +2904,7 @@ func TestMSSNotDelayed(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seq, RcvWnd: 30000, }) @@ -2912,6 +2929,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { // Check that data is received in chunks. bytesReceived := 0 numPackets := 0 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for bytesReceived != dataLen { b := c.GetPacket() numPackets++ @@ -2921,7 +2939,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2945,7 +2963,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), RcvWnd: 30000, TCPOpts: options, @@ -2961,7 +2979,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -3001,7 +3019,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -3210,7 +3228,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ) // Send SYN-ACK. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: tcpHdr.DestinationPort(), DstPort: tcpHdr.SourcePort(), @@ -3272,14 +3290,15 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Send RST segment. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, RcvWnd: 30000, }) @@ -3328,14 +3347,15 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send RST segment. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, RcvWnd: 30000, }) @@ -3363,7 +3383,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventHUp) @@ -3426,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) var r bytes.Reader r.Reset(make([]byte, 1)) @@ -3469,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) // Disabling PMTU discovery causes all packets sent from this socket to // have DF=0. This needs to be done because the IPv4 ID uniqueness @@ -3518,19 +3538,20 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3540,7 +3561,7 @@ func TestFinImmediately(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -3551,7 +3572,7 @@ func TestFinImmediately(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3561,19 +3582,20 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3584,7 +3606,7 @@ func TestFinRetransmit(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3594,7 +3616,7 @@ func TestFinRetransmit(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -3605,7 +3627,7 @@ func TestFinRetransmit(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3615,7 +3637,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := make([]byte, 10) @@ -3626,12 +3648,13 @@ func TestFinWithNoPendingData(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3641,7 +3664,7 @@ func TestFinWithNoPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3656,7 +3679,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3667,7 +3690,7 @@ func TestFinWithNoPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3678,7 +3701,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3688,7 +3711,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -3702,13 +3725,14 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := tcp.InitialCwnd; i > 0; i-- { checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3727,7 +3751,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3737,7 +3761,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3747,7 +3771,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3758,7 +3782,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3768,7 +3792,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3778,7 +3802,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := make([]byte, 10) @@ -3789,12 +3813,13 @@ func TestFinWithPendingData(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3804,7 +3829,7 @@ func TestFinWithPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3820,7 +3845,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3836,7 +3861,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3847,7 +3872,7 @@ func TestFinWithPendingData(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3857,7 +3882,7 @@ func TestFinWithPendingData(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3867,7 +3892,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -3879,12 +3904,13 @@ func TestFinWithPartialAck(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3894,7 +3920,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -3905,7 +3931,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3921,7 +3947,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3937,7 +3963,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(791), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3948,7 +3974,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: seqnum.Value(next - 1), RcvWnd: 30000, }) @@ -3961,7 +3987,7 @@ func TestFinWithPartialAck(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, + SeqNum: iss.Add(1), AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -4002,17 +4028,18 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) // Open up the window with a scaled value. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 1, }) @@ -4031,7 +4058,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4041,7 +4068,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagRst, - SeqNum: 790, + SeqNum: iss, }) } @@ -4054,15 +4081,16 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4083,14 +4111,15 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) vv := c.BuildSegment(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: seqnum.Value(iss), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4110,14 +4139,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4144,23 +4174,24 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) // Send 200 segments. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) data := []byte{1, 2, 3} for i := 0; i < 200; i++ { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), + SeqNum: iss.Add(seqnum.Size(i * len(data))), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) } // Receive ACKs for all segments. - last := seqnum.Value(790 + 200*len(data)) + last := iss.Add(seqnum.Size(200 * len(data))) for { b := c.GetPacket() checker.IPv4(t, b, @@ -4198,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4212,12 +4243,13 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -4232,7 +4264,7 @@ func TestReadAfterClosedState(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(2), RcvWnd: 30000, }) @@ -4242,7 +4274,7 @@ func TestReadAfterClosedState(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(791+len(data))), + checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4815,7 +4847,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -4833,6 +4865,7 @@ func TestPathMTUDiscovery(t *testing.T) { receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { var ret []byte + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i, size := range sizes { p := c.GetPacket() if i == which { @@ -4843,7 +4876,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4893,14 +4926,15 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5027,7 +5061,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { } if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) } if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { @@ -5067,7 +5101,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second @@ -5087,13 +5121,14 @@ func TestKeepalive(t *testing.T) { // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for i := 0; i < 10; i++ { b := c.GetPacket() checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -5103,7 +5138,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: c.IRS, RcvWnd: 30000, }) @@ -5128,7 +5163,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -5140,7 +5175,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -5153,7 +5188,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -5166,7 +5201,7 @@ func TestKeepalive(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(next-1)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -5184,7 +5219,7 @@ func TestKeepalive(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -5215,7 +5250,7 @@ func TestKeepalive(t *testing.T) { func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { t.Helper() // Send a SYN request. - irs = seqnum.Value(789) + irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: srcPort, DstPort: context.StackPort, @@ -5260,7 +5295,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { t.Helper() // Send a SYN request. - irs = seqnum.Value(789) + irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendV6Packet(nil, &context.Headers{ SrcPort: srcPort, DstPort: context.StackPort, @@ -5340,7 +5375,7 @@ func TestListenBacklogFull(t *testing.T) { SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), RcvWnd: 30000, }) c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) @@ -5491,7 +5526,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { t.Fatalf("Listen failed: %s", err) } - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacketWithAddrs(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5591,7 +5626,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { t.Fatalf("Listen failed: %s", err) } - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendV6PacketWithAddrs(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5645,7 +5680,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as // the synRcvd count will be equal to backlog. - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -5758,7 +5793,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { time.Sleep(50 * time.Millisecond) // Send a SYN request. - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ // pick a different src port for new SYN. SrcPort: context.TestPort + 1, @@ -5824,7 +5859,7 @@ func TestSYNRetransmit(t *testing.T) { // Send the same SYN packet multiple times. We should still get a valid SYN-ACK // reply. - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) for i := 0; i < 5; i++ { c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5867,7 +5902,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(789) + irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6051,7 +6086,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { SrcPort: srcPort + 2, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), RcvWnd: 30000, }) @@ -6501,7 +6536,7 @@ func TestTCPLingerTimeout(t *testing.T) { c := context.New(t, 1500 /* mtu */) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) testCases := []struct { name string @@ -6552,7 +6587,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6671,7 +6706,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6778,7 +6813,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6852,7 +6887,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Send a SYN request w/ sequence number lower than // the highest sequence number sent. We just reuse // the same number. - iss = seqnum.Value(789) + iss = seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6872,7 +6907,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Send a SYN request w/ sequence number higher than // the highest sequence number sent. - iss = seqnum.Value(792) + iss = iss.Add(3) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -6942,7 +6977,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -7091,7 +7126,7 @@ func TestTCPCloseWithData(t *testing.T) { } // Send a SYN request. - iss := seqnum.Value(789) + iss := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -7242,7 +7277,7 @@ func TestTCPUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventHUp) @@ -7268,12 +7303,13 @@ func TestTCPUserTimeout(t *testing.T) { } next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(790), + checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7299,7 +7335,7 @@ func TestTCPUserTimeout(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(next), RcvWnd: 30000, }) @@ -7328,7 +7364,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() @@ -7362,11 +7398,12 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(790)), + checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7383,7 +7420,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: iss, AckNum: seqnum.Value(c.IRS + 1), RcvWnd: 30000, }) @@ -7413,20 +7450,20 @@ func TestIncreaseWindowOnRead(t *testing.T) { defer c.Cleanup() const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for remain > len(data) { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -7438,7 +7475,7 @@ func TestIncreaseWindowOnRead(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7469,7 +7506,7 @@ func TestIncreaseWindowOnRead(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), @@ -7483,20 +7520,20 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { defer c.Cleanup() const rcvBuf = 65535 * 10 - c.CreateConnected(789, 30000, rcvBuf) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. remain := rcvBuf sent := 0 data := make([]byte, defaultMTU/2) - + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) for remain > len(data) { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), + SeqNum: iss.Add(seqnum.Size(sent)), AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -7507,7 +7544,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), @@ -7523,7 +7560,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(790+sent)), + checker.TCPAckNum(uint32(iss)+uint32(sent)), checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), @@ -7664,16 +7701,16 @@ func TestResetDuringClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - iss := seqnum.Value(789) - c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) // Send some data to make sure there is some unread // data to trigger a reset on c.Close. irs := c.IRS + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), + SeqNum: iss, AckNum: irs.Add(1), RcvWnd: 30000, }) @@ -7683,7 +7720,7 @@ func TestResetDuringClose(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss.Add(5))))) + checker.TCPAckNum(uint32(iss)+4))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7698,7 +7735,7 @@ func TestResetDuringClose(t *testing.T) { c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: iss.Add(5), + SeqNum: iss.Add(4), AckNum: c.IRS.Add(5), RcvWnd: 30000, Flags: header.TCPFlagRst, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 5d81dbb94..c8126b51b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -45,8 +45,8 @@ import ( // represents the remote endpoint. const ( v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" stackV4MappedAddr = v4MappedAddrPrefix + stackAddr testV4MappedAddr = v4MappedAddrPrefix + testAddr multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go index 047091e75..dbe17fa5e 100644 --- a/pkg/test/dockerutil/network.go +++ b/pkg/test/dockerutil/network.go @@ -102,11 +102,8 @@ func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) } -// Cleanup cleans up the docker network and all the containers attached to it. +// Cleanup cleans up the docker network. func (n *Network) Cleanup(ctx context.Context) error { - for _, c := range n.containers { - c.CleanUp(ctx) - } n.containers = nil return n.client.NetworkRemove(ctx, n.id) |