summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/p9/p9.go11
-rw-r--r--pkg/ring0/BUILD1
-rw-r--r--pkg/ring0/defs.go3
-rw-r--r--pkg/ring0/gen_offsets/BUILD1
-rw-r--r--pkg/ring0/kernel_amd64.go12
-rw-r--r--pkg/ring0/kernel_arm64.go4
-rw-r--r--pkg/sentry/arch/BUILD3
-rw-r--r--pkg/sentry/arch/arch.go34
-rw-r--r--pkg/sentry/arch/arch_aarch64.go71
-rw-r--r--pkg/sentry/arch/arch_amd64.go13
-rw-r--r--pkg/sentry/arch/arch_arm64.go13
-rw-r--r--pkg/sentry/arch/arch_state_x86.go54
-rw-r--r--pkg/sentry/arch/arch_x86.go217
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go3
-rw-r--r--pkg/sentry/arch/fpu/BUILD21
-rw-r--r--pkg/sentry/arch/fpu/fpu.go54
-rw-r--r--pkg/sentry/arch/fpu/fpu_amd64.go280
-rw-r--r--pkg/sentry/arch/fpu/fpu_amd64.s (renamed from pkg/sentry/arch/arch_amd64.s)0
-rw-r--r--pkg/sentry/arch/fpu/fpu_arm64.go63
-rw-r--r--pkg/sentry/arch/signal_amd64.go9
-rw-r--r--pkg/sentry/arch/signal_arm64.go7
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go21
-rw-r--r--pkg/sentry/fs/gofer/file.go27
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go14
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go21
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go21
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go43
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go10
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go10
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go2
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go33
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go13
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go3
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go3
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go9
-rw-r--r--pkg/sentry/syscalls/linux/error.go26
-rw-r--r--pkg/syserror/syserror.go10
-rw-r--r--pkg/tcpip/checker/checker.go7
-rw-r--r--pkg/tcpip/header/ndp_options.go156
-rw-r--r--pkg/tcpip/header/ndp_test.go695
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go32
-rw-r--r--pkg/tcpip/network/arp/arp.go21
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go139
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go124
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go121
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go60
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go155
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go16
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go18
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go14
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go72
-rw-r--r--pkg/tcpip/stack/ndp_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go14
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go31
-rw-r--r--pkg/tcpip/stack/stack_test.go138
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go4
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go261
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go148
-rw-r--r--pkg/tcpip/tests/integration/route_test.go132
-rw-r--r--pkg/tcpip/tests/utils/BUILD3
-rw-r--r--pkg/tcpip/tests/utils/utils.go60
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go507
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go4
-rw-r--r--pkg/test/dockerutil/network.go5
76 files changed, 2654 insertions, 1600 deletions
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 2235f8968..648cf4b49 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -151,9 +151,16 @@ const (
// Sticky is a mode bit indicating sticky directories.
Sticky FileMode = 01000
+ // SetGID is the set group ID bit.
+ SetGID FileMode = 02000
+
+ // SetUID is the set user ID bit.
+ SetUID FileMode = 04000
+
// permissionsMask is the mask to apply to FileModes for permissions. It
- // includes rwx bits for user, group and others, and sticky bit.
- permissionsMask FileMode = 01777
+ // includes rwx bits for user, group, and others, as well as the sticky
+ // bit, setuid bit, and setgid bit.
+ permissionsMask FileMode = 07777
)
// QIDType is the most significant byte of the FileMode word, to be used as the
diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD
index d1b14efdb..885958456 100644
--- a/pkg/ring0/BUILD
+++ b/pkg/ring0/BUILD
@@ -80,6 +80,7 @@ go_library(
"//pkg/ring0/pagetables",
"//pkg/safecopy",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/usermem",
],
)
diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go
index e8ce608ba..b6e2012e8 100644
--- a/pkg/ring0/defs.go
+++ b/pkg/ring0/defs.go
@@ -17,6 +17,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
)
// Kernel is a global kernel object.
@@ -96,7 +97,7 @@ type SwitchOpts struct {
// FloatingPointState is a byte pointer where floating point state is
// saved and restored.
- FloatingPointState arch.FloatingPointData
+ FloatingPointState *fpu.State
// PageTables are the application page tables.
PageTables *pagetables.PageTables
diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD
index 15b93d61c..f421e1687 100644
--- a/pkg/ring0/gen_offsets/BUILD
+++ b/pkg/ring0/gen_offsets/BUILD
@@ -35,6 +35,7 @@ go_binary(
"//pkg/cpuid",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/usermem",
],
)
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index e9e706716..33c259757 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -239,17 +239,17 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Ss = uint64(Udata) // Ditto.
// Perform the switch.
- swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
- LoadFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy in floating point.
+ swapgs() // GS will be swapped on return.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point.
if switchOpts.FullRestore {
vector = iret(c, regs, uintptr(userCR3))
} else {
vector = sysret(c, regs, uintptr(userCR3))
}
- SaveFloatingPoint(&switchOpts.FloatingPointState[0]) // escapes: no. Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index c9a120952..7975e5f92 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -62,7 +62,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
- storeEl0Fpstate(&switchOpts.FloatingPointState[0])
+ storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
if switchOpts.Flush {
FlushTlbByASID(uintptr(switchOpts.UserASID))
@@ -82,7 +82,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
fpDisableTrap = CPACREL1()
if fpDisableTrap != 0 {
- SaveFloatingPoint(&switchOpts.FloatingPointState[0])
+ SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer())
}
vector = c.vecCode
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 85278b389..f660f1614 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -9,7 +9,6 @@ go_library(
"arch.go",
"arch_aarch64.go",
"arch_amd64.go",
- "arch_amd64.s",
"arch_arm64.go",
"arch_state_x86.go",
"arch_x86.go",
@@ -36,8 +35,8 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/limits",
- "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 3443b9e1b..921151137 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -50,12 +51,6 @@ func (a Arch) String() string {
}
}
-// FloatingPointData is a generic type, and will always be passed as a pointer.
-// We rely on the individual arch implementations to meet all the necessary
-// requirements. For example, on x86 the region must be 16-byte aligned and 512
-// bytes in size.
-type FloatingPointData []byte
-
// Context provides architecture-dependent information for a specific thread.
//
// NOTE(b/34169503): Currently we use uintptr here to refer to a generic native
@@ -187,7 +182,7 @@ type Context interface {
ClearSingleStep()
// FloatingPointData will be passed to underlying save routines.
- FloatingPointData() FloatingPointData
+ FloatingPointData() *fpu.State
// NewMmapLayout returns a layout for a new MM, where MinAddr for the
// returned layout must be no lower than min, and MaxAddr for the returned
@@ -221,16 +216,6 @@ type Context interface {
// number of bytes read.
PtraceSetRegs(src io.Reader) (int, error)
- // PtraceGetFPRegs implements ptrace(PTRACE_GETFPREGS) by writing the
- // floating-point registers represented by this Context to addr in dst and
- // returning the number of bytes written.
- PtraceGetFPRegs(dst io.Writer) (int, error)
-
- // PtraceSetFPRegs implements ptrace(PTRACE_SETFPREGS) by reading
- // floating-point registers from src into this Context and returning the
- // number of bytes read.
- PtraceSetFPRegs(src io.Reader) (int, error)
-
// PtraceGetRegSet implements ptrace(PTRACE_GETREGSET) by writing the
// register set given by architecture-defined value regset from this
// Context to dst and returning the number of bytes written, which must be
@@ -365,18 +350,3 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
-
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 6b81e9708..08789f517 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,65 +41,11 @@ type Registers struct {
const (
// SyscallWidth is the width of insturctions.
SyscallWidth = 4
-
- // fpsimdMagic is the magic number which is used in fpsimd_context.
- fpsimdMagic = 0x46508001
-
- // fpsimdContextSize is the size of fpsimd_context.
- fpsimdContextSize = 0x210
)
// ARMTrapFlag is the mask for the trap flag.
const ARMTrapFlag = uint64(1) << 21
-// aarch64FPState is aarch64 floating point state.
-type aarch64FPState []byte
-
-// initAarch64FPState sets up initial state.
-//
-// Related code in Linux kernel: fpsimd_flush_thread().
-// FPCR = FPCR_RM_RN (0x0 << 22).
-//
-// Currently, aarch64FPState is only a space of 0x210 length for fpstate.
-// The fp head is useless in sentry/ptrace/kvm.
-//
-func initAarch64FPState(data aarch64FPState) {
-}
-
-func newAarch64FPStateSlice() []byte {
- return alignedBytes(4096, 16)[:fpsimdContextSize]
-}
-
-// newAarch64FPState returns an initialized floating point state.
-//
-// The returned state is large enough to store all floating point state
-// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet.
-func newAarch64FPState() aarch64FPState {
- f := aarch64FPState(newAarch64FPStateSlice())
- initAarch64FPState(f)
- return f
-}
-
-// fork creates and returns an identical copy of the aarch64 floating point state.
-func (f aarch64FPState) fork() aarch64FPState {
- n := aarch64FPState(newAarch64FPStateSlice())
- copy(n, f)
- return n
-}
-
-// FloatingPointData returns the raw data pointer.
-func (f aarch64FPState) FloatingPointData() FloatingPointData {
- return ([]byte)(f)
-}
-
-// NewFloatingPointData returns a new floating point data blob.
-//
-// This is primarily for use in tests.
-func NewFloatingPointData() FloatingPointData {
- return ([]byte)(newAarch64FPState())
-}
-
// State contains the common architecture bits for aarch64 (the build tag of this
// file ensures it's only built on aarch64).
//
@@ -108,7 +55,7 @@ type State struct {
Regs Registers
// Our floating point state.
- aarch64FPState `state:"wait"`
+ fpState fpu.State `state:"wait"`
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -162,10 +109,10 @@ func (s State) Proto() *rpb.Registers {
// Fork creates and returns an identical copy of the state.
func (s *State) Fork() State {
return State{
- Regs: s.Regs,
- aarch64FPState: s.aarch64FPState.fork(),
- FeatureSet: s.FeatureSet,
- OrigR0: s.OrigR0,
+ Regs: s.Regs,
+ fpState: s.fpState.Fork(),
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
}
}
@@ -318,10 +265,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case ARM64:
return &context64{
State{
- aarch64FPState: newAarch64FPState(),
- FeatureSet: fs,
+ fpState: fpu.NewState(),
+ FeatureSet: fs,
},
- []aarch64FPState(nil),
+ []fpu.State(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 15d8ddb40..2571be60f 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -105,7 +106,7 @@ const (
// +stateify savable
type context64 struct {
State
- sigFPState []x86FPState // fpstate to be restored on sigreturn.
+ sigFPState []fpu.State // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -113,14 +114,18 @@ func (c *context64) Arch() Arch {
return AMD64
}
-func (c *context64) copySigFPState() []x86FPState {
- var sigfps []x86FPState
+func (c *context64) copySigFPState() []fpu.State {
+ var sigfps []fpu.State
for _, s := range c.sigFPState {
- sigfps = append(sigfps, s.fork())
+ sigfps = append(sigfps, s.Fork())
}
return sigfps
}
+func (c *context64) FloatingPointData() *fpu.State {
+ return &c.State.fpState
+}
+
// Fork returns an exact copy of this context.
func (c *context64) Fork() Context {
return &context64{
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 0c61a3ff7..14ad9483b 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -79,7 +80,7 @@ const (
// +stateify savable
type context64 struct {
State
- sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
+ sigFPState []fpu.State // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -87,10 +88,10 @@ func (c *context64) Arch() Arch {
return ARM64
}
-func (c *context64) copySigFPState() []aarch64FPState {
- var sigfps []aarch64FPState
+func (c *context64) copySigFPState() []fpu.State {
+ var sigfps []fpu.State
for _, s := range c.sigFPState {
- sigfps = append(sigfps, s.fork())
+ sigfps = append(sigfps, s.Fork())
}
return sigfps
}
@@ -286,3 +287,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
// TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
return nil
}
+
+func (c *context64) FloatingPointData() *fpu.State {
+ return &c.State.fpState
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 840e53d33..b2b94c304 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -16,59 +16,7 @@
package arch
-import (
- "gvisor.dev/gvisor/pkg/cpuid"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
-// and SSE state, so this is the equivalent XSTATE_BV value.
-const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
-
// afterLoadFPState is invoked by afterLoad.
func (s *State) afterLoadFPState() {
- old := s.x86FPState
-
- // Recreate the slice. This is done to ensure that it is aligned
- // appropriately in memory, and large enough to accommodate any new
- // state that may be saved by the new CPU. Even if extraneous new state
- // is saved, the state we care about is guaranteed to be a subset of
- // new state. Later optimizations can use less space when using a
- // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
- // more info.
- s.x86FPState = newX86FPState()
-
- // x86FPState always contains all the FP state supported by the host.
- // We may have come from a newer machine that supports additional state
- // which we cannot restore.
- //
- // The x86 FP state areas are backwards compatible, so we can simply
- // truncate the additional floating point state.
- //
- // Applications should not depend on the truncated state because it
- // should relate only to features that were not exposed in the app
- // FeatureSet. However, because we do not *prevent* them from using
- // this state, we must verify here that there is no in-use state
- // (according to XSTATE_BV) which we do not support.
- if len(s.x86FPState) < len(old) {
- // What do we support?
- supportedBV := fxsaveBV
- if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
- supportedBV = fs.ValidXCR0Mask()
- }
-
- // What was in use?
- savedBV := fxsaveBV
- if len(old) >= xstateBVOffset+8 {
- savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
- }
-
- // Supported features must be a superset of saved features.
- if savedBV&^supportedBV != 0 {
- panic(ErrFloatingPoint{supported: supportedBV, saved: savedBV})
- }
- }
-
- // Copy to the new, aligned location.
- copy(s.x86FPState, old)
+ s.fpState.AfterLoad()
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 91edf0703..e8e52d3a8 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -24,10 +24,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Registers represents the CPU registers for this architecture.
@@ -111,57 +110,6 @@ var (
X86TrapFlag uint64 = (1 << 8)
)
-// x86FPState is x86 floating point state.
-type x86FPState []byte
-
-// initX86FPState (defined in asm files) sets up initial state.
-func initX86FPState(data *byte, useXsave bool)
-
-func newX86FPStateSlice() []byte {
- size, align := cpuid.HostFeatureSet().ExtendedStateSize()
- capacity := size
- // Always use at least 4096 bytes.
- //
- // For the KVM platform, this state is a fixed 4096 bytes, so make sure
- // that the underlying array is at _least_ that size otherwise we will
- // corrupt random memory. This is not a pleasant thing to debug.
- if capacity < 4096 {
- capacity = 4096
- }
- return alignedBytes(capacity, align)[:size]
-}
-
-// newX86FPState returns an initialized floating point state.
-//
-// The returned state is large enough to store all floating point state
-// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet. Since they may still be able to see state not advertised by
-// CPUID we must ensure it does not contain any sentry state.
-func newX86FPState() x86FPState {
- f := x86FPState(newX86FPStateSlice())
- initX86FPState(&f.FloatingPointData()[0], cpuid.HostFeatureSet().UseXsave())
- return f
-}
-
-// fork creates and returns an identical copy of the x86 floating point state.
-func (f x86FPState) fork() x86FPState {
- n := x86FPState(newX86FPStateSlice())
- copy(n, f)
- return n
-}
-
-// FloatingPointData returns the raw data pointer.
-func (f x86FPState) FloatingPointData() FloatingPointData {
- return []byte(f)
-}
-
-// NewFloatingPointData returns a new floating point data blob.
-//
-// This is primarily for use in tests.
-func NewFloatingPointData() FloatingPointData {
- return (FloatingPointData)(newX86FPState())
-}
-
// Proto returns a protobuf representation of the system registers in State.
func (s State) Proto() *rpb.Registers {
regs := &rpb.AMD64Registers{
@@ -200,7 +148,7 @@ func (s State) Proto() *rpb.Registers {
func (s *State) Fork() State {
return State{
Regs: s.Regs,
- x86FPState: s.x86FPState.fork(),
+ fpState: s.fpState.Fork(),
FeatureSet: s.FeatureSet,
}
}
@@ -393,149 +341,6 @@ func isValidSegmentBase(reg uint64) bool {
return reg < uint64(maxAddr64)
}
-// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
-// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
-// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
-const ptraceFPRegsSize = 512
-
-// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
-func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
- return dst.Write(s.x86FPState[:ptraceFPRegsSize])
-}
-
-// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
-func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
- var f [ptraceFPRegsSize]byte
- n, err := io.ReadFull(src, f[:])
- if err != nil {
- return 0, err
- }
- // Force reserved bits in MXCSR to 0. This is consistent with Linux.
- sanitizeMXCSR(x86FPState(f[:]))
- // N.B. this only copies the beginning of the FP state, which
- // corresponds to the FXSAVE area.
- copy(s.x86FPState, f[:])
- return n, nil
-}
-
-const (
- // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
- // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
- // Area")
- mxcsrOffset = 24
-
- // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
- // start of the FXSAVE area.
- mxcsrMaskOffset = 28
-)
-
-var (
- mxcsrMask uint32
- initMXCSRMask sync.Once
-)
-
-// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
-// generates a general-protection fault (#GP) in response to an attempt to set
-// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
-// 10.5.1.2 "SSE State")
-func sanitizeMXCSR(f x86FPState) {
- mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
- initMXCSRMask.Do(func() {
- temp := x86FPState(alignedBytes(uint(ptraceFPRegsSize), 16))
- initX86FPState(&temp.FloatingPointData()[0], false /* useXsave */)
- mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
- if mxcsrMask == 0 {
- // "If the value of the MXCSR_MASK field is 00000000H, then the
- // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
- // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
- // Register"
- mxcsrMask = 0xffbf
- }
- })
- mxcsr &= mxcsrMask
- usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
-}
-
-const (
- // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
- // to the size of the XSAVE legacy area (512 bytes) plus the size of the
- // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
- // X86_XSTATE_SSE_SIZE.
- minXstateBytes = 512 + 64
-
- // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
- // field in Linux's struct user_xstateregs, which is the type manipulated
- // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
- // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
- userXstateXCR0Offset = 464
-
- // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
- // XSAVE area.
- xstateBVOffset = 512
-
- // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
- // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
- // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
- // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
- // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
- // exceptions resulting from invalid values; we aren't. Linux also never
- // uses the compacted format when doing XSAVE and doesn't even define the
- // compaction extensions to XSAVE as a CPU feature, so for simplicity we
- // assume no one is using them.
- xsaveHeaderZeroedOffset = 512 + 8
- xsaveHeaderZeroedBytes = 64 - 8
-)
-
-func (s *State) ptraceGetXstateRegs(dst io.Writer, maxlen int) (int, error) {
- // N.B. s.x86FPState may contain more state than the application
- // expects. We only copy the subset that would be in their XSAVE area.
- ess, _ := s.FeatureSet.ExtendedStateSize()
- f := make([]byte, ess)
- copy(f, s.x86FPState)
- // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
- // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
- // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
- // mask. GDB relies on this: see
- // gdb/x86-linux-nat.c:x86_linux_read_description().
- usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], s.FeatureSet.ValidXCR0Mask())
- if len(f) > maxlen {
- f = f[:maxlen]
- }
- return dst.Write(f)
-}
-
-func (s *State) ptraceSetXstateRegs(src io.Reader, maxlen int) (int, error) {
- // Allow users to pass an xstate register set smaller than ours (they can
- // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
- // Also allow users to pass a register set larger than ours; anything after
- // their ExtendedStateSize will be ignored. (I think Linux technically
- // permits setting a register set smaller than minXstateBytes, but it has
- // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
- if maxlen < minXstateBytes {
- return 0, unix.EFAULT
- }
- ess, _ := s.FeatureSet.ExtendedStateSize()
- if maxlen > int(ess) {
- maxlen = int(ess)
- }
- f := make([]byte, maxlen)
- if _, err := io.ReadFull(src, f); err != nil {
- return 0, err
- }
- // Force reserved bits in MXCSR to 0. This is consistent with Linux.
- sanitizeMXCSR(x86FPState(f))
- // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
- xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
- xstateBV &= s.FeatureSet.ValidXCR0Mask()
- usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
- // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
- reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
- for i := range reserved {
- reserved[i] = 0
- }
- return copy(s.x86FPState, f), nil
-}
-
// Register sets defined in include/uapi/linux/elf.h.
const (
_NT_PRSTATUS = 1
@@ -552,12 +357,9 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
}
return s.PtraceGetRegs(dst)
case _NT_PRFPREG:
- if maxlen < ptraceFPRegsSize {
- return 0, syserror.EFAULT
- }
- return s.PtraceGetFPRegs(dst)
+ return s.fpState.PtraceGetFPRegs(dst, maxlen)
case _NT_X86_XSTATE:
- return s.ptraceGetXstateRegs(dst, maxlen)
+ return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet)
default:
return 0, syserror.EINVAL
}
@@ -572,12 +374,9 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int,
}
return s.PtraceSetRegs(src)
case _NT_PRFPREG:
- if maxlen < ptraceFPRegsSize {
- return 0, syserror.EFAULT
- }
- return s.PtraceSetFPRegs(src)
+ return s.fpState.PtraceSetFPRegs(src, maxlen)
case _NT_X86_XSTATE:
- return s.ptraceSetXstateRegs(src, maxlen)
+ return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet)
default:
return 0, syserror.EINVAL
}
@@ -609,10 +408,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case AMD64:
return &context64{
State{
- x86FPState: newX86FPState(),
+ fpState: fpu.NewState(),
FeatureSet: fs,
},
- []x86FPState(nil),
+ []fpu.State(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
index 0c73fcbfb..5d7b99bd9 100644
--- a/pkg/sentry/arch/arch_x86_impl.go
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -18,6 +18,7 @@ package arch
import (
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
)
// State contains the common architecture bits for X86 (the build tag of this
@@ -29,7 +30,7 @@ type State struct {
Regs Registers
// Our floating point state.
- x86FPState `state:"wait"`
+ fpState fpu.State `state:"wait"`
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD
new file mode 100644
index 000000000..0a5395267
--- /dev/null
+++ b/pkg/sentry/arch/fpu/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fpu",
+ srcs = [
+ "fpu.go",
+ "fpu_amd64.go",
+ "fpu_amd64.s",
+ "fpu_arm64.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/cpuid",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go
new file mode 100644
index 000000000..867d309a3
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fpu provides basic floating point helpers.
+package fpu
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// State represents floating point state.
+//
+// This is a simple byte slice, but may have architecture-specific methods
+// attached to it.
+type State []byte
+
+// ErrLoadingState indicates a failed restore due to unusable floating point
+// state.
+type ErrLoadingState struct {
+ // supported is the supported floating point state.
+ supportedFeatures uint64
+
+ // saved is the saved floating point state.
+ savedFeatures uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrLoadingState) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures)
+}
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go
new file mode 100644
index 000000000..3a62f51be
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_amd64.go
@@ -0,0 +1,280 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 i386
+
+package fpu
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// initX86FPState (defined in asm files) sets up initial state.
+func initX86FPState(data *byte, useXsave bool)
+
+func newX86FPStateSlice() State {
+ size, align := cpuid.HostFeatureSet().ExtendedStateSize()
+ capacity := size
+ // Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
+ if capacity < 4096 {
+ capacity = 4096
+ }
+ return alignedBytes(capacity, align)[:size]
+}
+
+// NewState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet. Since they may still be able to see state not advertised by
+// CPUID we must ensure it does not contain any sentry state.
+func NewState() State {
+ f := newX86FPStateSlice()
+ initX86FPState(&f[0], cpuid.HostFeatureSet().UseXsave())
+ return f
+}
+
+// Fork creates and returns an identical copy of the x86 floating point state.
+func (s *State) Fork() State {
+ n := newX86FPStateSlice()
+ copy(n, *s)
+ return n
+}
+
+// ptraceFPRegsSize is the size in bytes of Linux's user_i387_struct, the type
+// manipulated by PTRACE_GETFPREGS and PTRACE_SETFPREGS on x86. Equivalently,
+// ptraceFPRegsSize is the size in bytes of the x86 FXSAVE area.
+const ptraceFPRegsSize = 512
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer, maxlen int) (int, error) {
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+
+ return dst.Write((*s)[:ptraceFPRegsSize])
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader, maxlen int) (int, error) {
+ if maxlen < ptraceFPRegsSize {
+ return 0, syserror.EFAULT
+ }
+
+ var f [ptraceFPRegsSize]byte
+ n, err := io.ReadFull(src, f[:])
+ if err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(State(f[:]))
+ // N.B. this only copies the beginning of the FP state, which
+ // corresponds to the FXSAVE area.
+ copy(*s, f[:])
+ return n, nil
+}
+
+const (
+ // mxcsrOffset is the offset in bytes of the MXCSR field from the start of
+ // the FXSAVE area. (Intel SDM Vol. 1, Table 10-2 "Format of an FXSAVE
+ // Area")
+ mxcsrOffset = 24
+
+ // mxcsrMaskOffset is the offset in bytes of the MXCSR_MASK field from the
+ // start of the FXSAVE area.
+ mxcsrMaskOffset = 28
+)
+
+var (
+ mxcsrMask uint32
+ initMXCSRMask sync.Once
+)
+
+const (
+ // minXstateBytes is the minimum size in bytes of an x86 XSAVE area, equal
+ // to the size of the XSAVE legacy area (512 bytes) plus the size of the
+ // XSAVE header (64 bytes). Equivalently, minXstateBytes is GDB's
+ // X86_XSTATE_SSE_SIZE.
+ minXstateBytes = 512 + 64
+
+ // userXstateXCR0Offset is the offset in bytes of the USER_XSTATE_XCR0_WORD
+ // field in Linux's struct user_xstateregs, which is the type manipulated
+ // by ptrace(PTRACE_GET/SETREGSET, NT_X86_XSTATE). Equivalently,
+ // userXstateXCR0Offset is GDB's I386_LINUX_XSAVE_XCR0_OFFSET.
+ userXstateXCR0Offset = 464
+
+ // xstateBVOffset is the offset in bytes of the XSTATE_BV field in an x86
+ // XSAVE area.
+ xstateBVOffset = 512
+
+ // xsaveHeaderZeroedOffset and xsaveHeaderZeroedBytes indicate parts of the
+ // XSAVE header that we coerce to zero: "Bytes 15:8 of the XSAVE header is
+ // a state-component bitmap called XCOMP_BV. ... Bytes 63:16 of the XSAVE
+ // header are reserved." - Intel SDM Vol. 1, Section 13.4.2 "XSAVE Header".
+ // Linux ignores XCOMP_BV, but it's able to recover from XRSTOR #GP
+ // exceptions resulting from invalid values; we aren't. Linux also never
+ // uses the compacted format when doing XSAVE and doesn't even define the
+ // compaction extensions to XSAVE as a CPU feature, so for simplicity we
+ // assume no one is using them.
+ xsaveHeaderZeroedOffset = 512 + 8
+ xsaveHeaderZeroedBytes = 64 - 8
+)
+
+// sanitizeMXCSR coerces reserved bits in the MXCSR field of f to 0. ("FXRSTOR
+// generates a general-protection fault (#GP) in response to an attempt to set
+// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
+// 10.5.1.2 "SSE State")
+func sanitizeMXCSR(f State) {
+ mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
+ initMXCSRMask.Do(func() {
+ temp := State(alignedBytes(uint(ptraceFPRegsSize), 16))
+ initX86FPState(&temp[0], false /* useXsave */)
+ mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
+ if mxcsrMask == 0 {
+ // "If the value of the MXCSR_MASK field is 00000000H, then the
+ // MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
+ // Vol. 1, Section 11.6.6 "Guidelines for Writing to the MXCSR
+ // Register"
+ mxcsrMask = 0xffbf
+ }
+ })
+ mxcsr &= mxcsrMask
+ usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
+}
+
+// PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by
+// writing the floating point registers from this state to dst and returning the
+// number of bytes written, which must be less than or equal to maxlen.
+func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid.FeatureSet) (int, error) {
+ // N.B. s.x86FPState may contain more state than the application
+ // expects. We only copy the subset that would be in their XSAVE area.
+ ess, _ := featureSet.ExtendedStateSize()
+ f := make([]byte, ess)
+ copy(f, *s)
+ // "The XSAVE feature set does not use bytes 511:416; bytes 463:416 are
+ // reserved." - Intel SDM Vol 1., Section 13.4.1 "Legacy Region of an XSAVE
+ // Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
+ // mask. GDB relies on this: see
+ // gdb/x86-linux-nat.c:x86_linux_read_description().
+ usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask())
+ if len(f) > maxlen {
+ f = f[:maxlen]
+ }
+ return dst.Write(f)
+}
+
+// PtraceSetXstateRegs implements ptrace(PTRACE_SETREGS, NT_X86_XSTATE) by
+// reading floating point registers from src and returning the number of bytes
+// read, which must be less than or equal to maxlen.
+func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid.FeatureSet) (int, error) {
+ // Allow users to pass an xstate register set smaller than ours (they can
+ // mask bits out of XSTATE_BV), as long as it's at least minXstateBytes.
+ // Also allow users to pass a register set larger than ours; anything after
+ // their ExtendedStateSize will be ignored. (I think Linux technically
+ // permits setting a register set smaller than minXstateBytes, but it has
+ // the same silent truncation behavior in kernel/ptrace.c:ptrace_regset().)
+ if maxlen < minXstateBytes {
+ return 0, unix.EFAULT
+ }
+ ess, _ := featureSet.ExtendedStateSize()
+ if maxlen > int(ess) {
+ maxlen = int(ess)
+ }
+ f := make([]byte, maxlen)
+ if _, err := io.ReadFull(src, f); err != nil {
+ return 0, err
+ }
+ // Force reserved bits in MXCSR to 0. This is consistent with Linux.
+ sanitizeMXCSR(State(f))
+ // Users can't enable *more* XCR0 bits than what we, and the CPU, support.
+ xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
+ xstateBV &= featureSet.ValidXCR0Mask()
+ usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
+ // Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
+ reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
+ for i := range reserved {
+ reserved[i] = 0
+ }
+ return copy(*s, f), nil
+}
+
+// BytePointer returns a pointer to the first byte of the state.
+//
+//go:nosplit
+func (s *State) BytePointer() *byte {
+ return &(*s)[0]
+}
+
+// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
+// and SSE state, so this is the equivalent XSTATE_BV value.
+const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
+
+// AfterLoad converts the loaded state to the format that compatible with the
+// current processor.
+func (s *State) AfterLoad() {
+ old := *s
+
+ // Recreate the slice. This is done to ensure that it is aligned
+ // appropriately in memory, and large enough to accommodate any new
+ // state that may be saved by the new CPU. Even if extraneous new state
+ // is saved, the state we care about is guaranteed to be a subset of
+ // new state. Later optimizations can use less space when using a
+ // smaller state component bitmap. Intel SDM Volume 1 Chapter 13 has
+ // more info.
+ *s = NewState()
+
+ // x86FPState always contains all the FP state supported by the host.
+ // We may have come from a newer machine that supports additional state
+ // which we cannot restore.
+ //
+ // The x86 FP state areas are backwards compatible, so we can simply
+ // truncate the additional floating point state.
+ //
+ // Applications should not depend on the truncated state because it
+ // should relate only to features that were not exposed in the app
+ // FeatureSet. However, because we do not *prevent* them from using
+ // this state, we must verify here that there is no in-use state
+ // (according to XSTATE_BV) which we do not support.
+ if len(*s) < len(old) {
+ // What do we support?
+ supportedBV := fxsaveBV
+ if fs := cpuid.HostFeatureSet(); fs.UseXsave() {
+ supportedBV = fs.ValidXCR0Mask()
+ }
+
+ // What was in use?
+ savedBV := fxsaveBV
+ if len(old) >= xstateBVOffset+8 {
+ savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
+ }
+
+ // Supported features must be a superset of saved features.
+ if savedBV&^supportedBV != 0 {
+ panic(ErrLoadingState{supportedFeatures: supportedBV, savedFeatures: savedBV})
+ }
+ }
+
+ // Copy to the new, aligned location.
+ copy(*s, old)
+}
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/fpu/fpu_amd64.s
index 6c10336e7..6c10336e7 100644
--- a/pkg/sentry/arch/arch_amd64.s
+++ b/pkg/sentry/arch/fpu/fpu_amd64.s
diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go
new file mode 100644
index 000000000..d2f62631d
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_arm64.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package fpu
+
+const (
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
+)
+
+// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is only a space of 0x210 length for fpstate.
+// The fp head is useless in sentry/ptrace/kvm.
+//
+func initAarch64FPState(data *State) {
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
+}
+
+// NewState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet.
+func NewState() State {
+ f := State(newAarch64FPStateSlice())
+ initAarch64FPState(&f)
+ return f
+}
+
+// Fork creates and returns an identical copy of the aarch64 floating point state.
+func (s *State) Fork() State {
+ n := State(newAarch64FPStateSlice())
+ copy(n, *s)
+ return n
+}
+
+// BytePointer returns a pointer to the first byte of the state.
+func (s *State) BytePointer() *byte {
+ return &(*s)[0]
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index e6557cab6..ee3743483 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -98,7 +99,7 @@ func (c *context64) NewSignalStack() NativeSignalStack {
const _FP_XSTATE_MAGIC2_SIZE = 4
func (c *context64) fpuFrameSize() (size int, useXsave bool) {
- size = len(c.x86FPState)
+ size = len(c.fpState)
if size > 512 {
// Make room for the magic cookie at the end of the xsave frame.
size += _FP_XSTATE_MAGIC2_SIZE
@@ -226,10 +227,10 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Ss = userDS
// Save the thread's floating point state.
- c.sigFPState = append(c.sigFPState, c.x86FPState)
+ c.sigFPState = append(c.sigFPState, c.fpState)
// Signal handler gets a clean floating point state.
- c.x86FPState = newX86FPState()
+ c.fpState = fpu.NewState()
return nil
}
@@ -273,7 +274,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt
// Restore floating point state.
l := len(c.sigFPState)
if l > 0 {
- c.x86FPState = c.sigFPState[l-1]
+ c.fpState = c.sigFPState[l-1]
// NOTE(cl/133042258): State save requires that any slice
// elements from '[len:cap]' to be zero value.
c.sigFPState[l-1] = nil
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 4491008c2..53281dcba 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -139,9 +140,9 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Regs[30] = uint64(act.Restorer)
// Save the thread's floating point state.
- c.sigFPState = append(c.sigFPState, c.aarch64FPState)
+ c.sigFPState = append(c.sigFPState, c.fpState)
// Signal handler gets a clean floating point state.
- c.aarch64FPState = newAarch64FPState()
+ c.fpState = fpu.NewState()
return nil
}
@@ -166,7 +167,7 @@ func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalSt
// Restore floating point state.
l := len(c.sigFPState)
if l > 0 {
- c.aarch64FPState = c.sigFPState[l-1]
+ c.fpState = c.sigFPState[l-1]
// NOTE(cl/133042258): State save requires that any slice
// elements from '[len:cap]' to be zero value.
c.sigFPState[l-1] = nil
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 82eda3e43..0ed7aafa5 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -380,16 +380,17 @@ func (c *CachingInodeOperations) Allocate(ctx context.Context, offset, length in
return nil
}
-// WriteOut implements fs.InodeOperations.WriteOut.
-func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+// WriteDirtyPagesAndAttrs will write the dirty pages and attributes to the
+// gofer without calling Fsync on the remote file.
+func (c *CachingInodeOperations) WriteDirtyPagesAndAttrs(ctx context.Context, inode *fs.Inode) error {
c.attrMu.Lock()
+ defer c.attrMu.Unlock()
+ c.dataMu.Lock()
+ defer c.dataMu.Unlock()
// Write dirty pages back.
- c.dataMu.Lock()
err := SyncDirtyAll(ctx, &c.cache, &c.dirty, uint64(c.attr.Size), c.mfp.MemoryFile(), c.backingFile.WriteFromBlocksAt)
- c.dataMu.Unlock()
if err != nil {
- c.attrMu.Unlock()
return err
}
@@ -399,12 +400,18 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
// Write out cached attributes.
if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
- c.attrMu.Unlock()
return err
}
c.dirtyAttr = fs.AttrMask{}
- c.attrMu.Unlock()
+ return nil
+}
+
+// WriteOut implements fs.InodeOperations.WriteOut.
+func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if err := c.WriteDirtyPagesAndAttrs(ctx, inode); err != nil {
+ return err
+ }
// Fsync the remote file.
return c.backingFile.Sync(ctx)
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 06d450ba6..8f5a87120 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -204,20 +204,8 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr
return entries, nil
}
-// maybeSync will call FSync on the file if either the cache policy or file
-// flags require it.
+// maybeSync will call FSync on the file if the file flags require it.
func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error {
- if n == 0 {
- // Nothing to sync.
- return nil
- }
-
- if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
- // Call WriteOut directly, as some "writethrough" filesystems
- // do not support sync.
- return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
- }
-
flags := file.Flags()
var syncType fs.SyncType
switch {
@@ -254,6 +242,19 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
}
+ if n == 0 {
+ // Nothing written. We are done.
+ return 0, err
+ }
+
+ // Write the dirty pages and attributes if cache policy tells us to.
+ if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
+ if werr := f.inodeOperations.cachingInodeOps.WriteDirtyPagesAndAttrs(ctx, file.Dirent.Inode); werr != nil {
+ // Report no bytes written since the write faild.
+ return 0, werr
+ }
+ }
+
// We may need to sync the written bytes.
if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil {
// Sync failed. Report 0 bytes written, since none of them are
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index c34451269..43c3c5a2d 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -783,7 +783,15 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
- if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ // If the parent is a setgid directory, use the parent's GID
+ // rather than the caller's and enable setgid.
+ kgid := creds.EffectiveKGID
+ mode := opts.Mode
+ if atomic.LoadUint32(&parent.mode)&linux.S_ISGID != 0 {
+ kgid = auth.KGID(atomic.LoadUint32(&parent.gid))
+ mode |= linux.S_ISGID
+ }
+ if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
@@ -1145,7 +1153,15 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
name := rp.Component()
// We only want the access mode for creating the file.
createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
- fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+
+ // If the parent is a setgid directory, use the parent's GID rather
+ // than the caller's.
+ kgid := creds.EffectiveKGID
+ if atomic.LoadUint32(&d.mode)&linux.S_ISGID != 0 {
+ kgid = auth.KGID(atomic.LoadUint32(&d.gid))
+ }
+
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
if err != nil {
dirfile.close(ctx)
return nil, err
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 71569dc65..692da02c1 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1102,10 +1102,26 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+
+ // As with Linux, if the UID, GID, or file size is changing, we have to
+ // clear permission bits. Note that when set, clearSGID causes
+ // permissions to be updated, but does not modify stat.Mask, as
+ // modification would cause an extra inotify flag to be set.
+ clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
+ stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+ stat.Mask&linux.STATX_SIZE != 0
+ if clearSGID {
+ if stat.Mask&linux.STATX_MODE != 0 {
+ stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
+ } else {
+ stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+ }
+ }
+
if !d.isSynthetic() {
if stat.Mask != 0 {
if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
+ Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
UID: stat.Mask&linux.STATX_UID != 0,
GID: stat.Mask&linux.STATX_GID != 0,
Size: stat.Mask&linux.STATX_SIZE != 0,
@@ -1140,7 +1156,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 {
+ if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
if stat.Mask&linux.STATX_UID != 0 {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 283b220bb..4f1ad0c88 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -266,6 +266,20 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
return 0, offset, err
}
}
+
+ // As with Linux, writing clears the setuid and setgid bits.
+ if n > 0 {
+ oldMode := atomic.LoadUint32(&d.mode)
+ // If setuid or setgid were set, update d.mode and propagate
+ // changes to the host.
+ if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode {
+ atomic.StoreUint32(&d.mode, newMode)
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
+ return 0, offset, err
+ }
+ }
+ }
+
return n, offset + n, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 84e37f793..46c500427 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -689,13 +689,9 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
return err
}
- creds := rp.Credentials()
+
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, rp.Credentials()),
}); err != nil {
if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr))
@@ -750,11 +746,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
creds := rp.Credentials()
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, creds),
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr))
@@ -963,14 +955,11 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
+
// Change the file's owner to the caller. We can't use upperFD.SetStat()
// because it will pick up creds from ctx.
if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- },
+ Stat: parent.newChildOwnerStat(opts.Mode, creds),
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr))
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 58680bc80..454c20d4f 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -749,6 +749,27 @@ func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
)
}
+// newChildOwnerStat returns a Statx for configuring the UID, GID, and mode of
+// children.
+func (d *dentry) newChildOwnerStat(mode linux.FileMode, creds *auth.Credentials) linux.Statx {
+ stat := linux.Statx{
+ Mask: uint32(linux.STATX_UID | linux.STATX_GID),
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ }
+ // Set GID and possibly the SGID bit if the parent is an SGID directory.
+ d.copyMu.RLock()
+ defer d.copyMu.RUnlock()
+ if atomic.LoadUint32(&d.mode)&linux.ModeSetGID == linux.ModeSetGID {
+ stat.GID = atomic.LoadUint32(&d.gid)
+ if stat.Mode&linux.ModeDirectory == linux.ModeDirectory {
+ stat.Mode = uint16(mode) | linux.ModeSetGID
+ stat.Mask |= linux.STATX_MODE
+ }
+ }
+ return stat
+}
+
// fileDescription is embedded by overlay implementations of
// vfs.FileDescriptionImpl.
//
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 25c785fd4..d791c06db 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -205,6 +205,20 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
if err := wrappedFD.SetStat(ctx, opts); err != nil {
return err
}
+
+ // Changing owners may clear one or both of the setuid and setgid bits,
+ // so we may have to update opts before setting d.mode.
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{
+ Mask: linux.STATX_MODE,
+ })
+ if err != nil {
+ return err
+ }
+ opts.Stat.Mode = stat.Mode
+ opts.Stat.Mask |= linux.STATX_MODE
+ }
+
d.updateAfterSetStatLocked(&opts)
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
@@ -295,7 +309,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, err
}
defer wrappedFD.DecRef(ctx)
- return wrappedFD.PWrite(ctx, src, offset, opts)
+ n, err := wrappedFD.PWrite(ctx, src, offset, opts)
+ if err != nil {
+ return n, err
+ }
+ return fd.updateSetUserGroupIDs(ctx, wrappedFD, n)
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -307,7 +325,28 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
if err != nil {
return 0, err
}
- return wrappedFD.Write(ctx, src, opts)
+ n, err := wrappedFD.Write(ctx, src, opts)
+ if err != nil {
+ return n, err
+ }
+ return fd.updateSetUserGroupIDs(ctx, wrappedFD, n)
+}
+
+func (fd *regularFileFD) updateSetUserGroupIDs(ctx context.Context, wrappedFD *vfs.FileDescription, written int64) (int64, error) {
+ // Writing can clear the setuid and/or setgid bits. We only have to
+ // check this if something was written and one of those bits was set.
+ dentry := fd.dentry()
+ if written == 0 || atomic.LoadUint32(&dentry.mode)&(linux.S_ISUID|linux.S_ISGID) == 0 {
+ return written, nil
+ }
+ stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return written, err
+ }
+ dentry.copyMu.Lock()
+ defer dentry.copyMu.Unlock()
+ atomic.StoreUint32(&dentry.mode, uint32(stat.Mode))
+ return written, nil
}
// Seek implements vfs.FileDescriptionImpl.Seek.
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index 609ad3941..7aea3dcd8 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -51,14 +51,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
return err
case linux.PTRACE_GETFPREGS:
- _, err := target.Arch().PtraceGetFPRegs(&usermem.IOReadWriter{
+ s := target.Arch().FloatingPointData()
+ _, err := target.Arch().FloatingPointData().PtraceGetFPRegs(&usermem.IOReadWriter{
Ctx: t,
IO: t.MemoryManager(),
Addr: data,
Opts: usermem.IOOpts{
AddressSpaceActive: true,
},
- })
+ }, len(*s))
return err
case linux.PTRACE_SETREGS:
@@ -73,14 +74,15 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
return err
case linux.PTRACE_SETFPREGS:
- _, err := target.Arch().PtraceSetFPRegs(&usermem.IOReadWriter{
+ s := target.Arch().FloatingPointData()
+ _, err := s.PtraceSetFPRegs(&usermem.IOReadWriter{
Ctx: t,
IO: t.MemoryManager(),
Addr: data,
Opts: usermem.IOOpts{
AddressSpaceActive: true,
},
- })
+ }, len(*s))
return err
default:
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 4f9e781af..03a76eb9b 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
@@ -78,6 +79,7 @@ go_test(
"//pkg/ring0",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/time",
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 308696efe..d761bbdee 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -73,7 +73,7 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no.
+ ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
ring0.Halt()
ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
@@ -92,7 +92,7 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint(&c.floatingPointState[0]) // escapes: no.
+ ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
ring0.Halt()
ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
@@ -124,5 +124,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// Set the context pointer to the saved floating point state. This is
// where the guest data has been serialized, the kernel will restore
// from this new pointer value.
- context.Fpstate = uint64(uintptrValue(&c.floatingPointState[0]))
+ context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index c317f1e99..578852c3f 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -92,7 +92,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
context.Fpsimd64.Fpsr = fpsimd.Fpsr
context.Fpsimd64.Fpcr = fpsimd.Fpcr
context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -112,12 +112,12 @@ func (c *vCPU) KernelSyscall() {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(&c.floatingPointState[0])
+ ring0.SaveVRegs(c.floatingPointState.BytePointer())
}
ring0.Halt()
@@ -136,12 +136,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(&c.floatingPointState[0])
+ fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(&c.floatingPointState[0])
+ ring0.SaveVRegs(c.floatingPointState.BytePointer())
}
ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index 76fc594a0..e44e995a0 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -33,7 +33,7 @@ func TestSegments(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 6243b9a04..5bce16dde 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -25,13 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
-var dummyFPState = (*byte)(arch.NewFloatingPointData())
+var dummyFPState = fpu.NewState()
type testHarness interface {
Errorf(format string, args ...interface{})
@@ -159,7 +160,7 @@ func TestApplicationSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -173,7 +174,7 @@ func TestApplicationSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -190,7 +191,7 @@ func TestApplicationFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -205,7 +206,7 @@ func TestApplicationFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -223,7 +224,7 @@ func TestRegistersSyscall(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
continue // Retry.
@@ -246,7 +247,7 @@ func TestRegistersFault(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -272,7 +273,7 @@ func TestBounce(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err != platform.ErrContextInterrupt {
t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -287,7 +288,7 @@ func TestBounce(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err != platform.ErrContextInterrupt {
@@ -319,7 +320,7 @@ func TestBounceStress(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err != platform.ErrContextInterrupt {
t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
@@ -340,7 +341,7 @@ func TestInvalidate(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
continue // Retry.
@@ -355,7 +356,7 @@ func TestInvalidate(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
Flush: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -379,7 +380,7 @@ func TestEmptyAddressSpace(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -393,7 +394,7 @@ func TestEmptyAddressSpace(t *testing.T) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
FullRestore: true,
}, &si); err == platform.ErrContextInterrupt {
@@ -469,7 +470,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
a++
@@ -506,7 +507,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
- FloatingPointState: dummyFPState,
+ FloatingPointState: &dummyFPState,
PageTables: pt,
}, &si); err == platform.ErrContextInterrupt {
a++
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 916903881..3af96c7e5 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
@@ -70,7 +71,7 @@ type vCPUArchState struct {
// floatingPointState is the floating point state buffer used in guest
// to host transitions. See usage in bluepill_amd64.go.
- floatingPointState arch.FloatingPointData
+ floatingPointState fpu.State
}
const (
@@ -151,7 +152,7 @@ func (c *vCPU) initArchState() error {
// This will be saved prior to leaving the guest, and we restore from
// this always. We cannot use the pointer in the context alone because
// we don't know how large the area there is in reality.
- c.floatingPointState = arch.NewFloatingPointData()
+ c.floatingPointState = fpu.NewState()
// Set the time offset to the host native time.
return c.setSystemTime()
@@ -307,12 +308,12 @@ func loadByte(ptr *byte) byte {
// emulate instructions like xsave and xrstor.
//
//go:nosplit
-func prefaultFloatingPointState(data arch.FloatingPointData) {
- size := len(data)
+func prefaultFloatingPointState(data *fpu.State) {
+ size := len(*data)
for i := 0; i < size; i += usermem.PageSize {
- loadByte(&(data)[i])
+ loadByte(&(*data)[i])
}
- loadByte(&(data)[size-1])
+ loadByte(&(*data)[size-1])
}
// SwitchToUser unpacks architectural-details.
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 3d715e570..2edc9d1b2 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -32,7 +33,7 @@ type vCPUArchState struct {
// floatingPointState is the floating point state buffer used in guest
// to host transitions. See usage in bluepill_arm64.go.
- floatingPointState arch.FloatingPointData
+ floatingPointState fpu.State
}
const (
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 059aa43d0..e7d5f3193 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -150,7 +151,7 @@ func (c *vCPU) initArchState() error {
c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
}
- c.floatingPointState = arch.NewFloatingPointData()
+ c.floatingPointState = fpu.NewState()
return c.setSystemTime()
}
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index fc43cc3c0..47efde6a2 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/arch/fpu",
"//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 6259350ec..01e73b019 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -62,9 +63,9 @@ func (t *thread) setRegs(regs *arch.Registers) error {
}
// getFPRegs gets the floating-point data via the GETREGSET ptrace unix.
-func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+func (t *thread) getFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error {
iovec := unix.Iovec{
- Base: (*byte)(&fpState[0]),
+ Base: fpState.BytePointer(),
Len: fpLen,
}
_, _, errno := unix.RawSyscall6(
@@ -81,9 +82,9 @@ func (t *thread) getFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsav
}
// setFPRegs sets the floating-point data via the SETREGSET ptrace unix.
-func (t *thread) setFPRegs(fpState arch.FloatingPointData, fpLen uint64, useXsave bool) error {
+func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) error {
iovec := unix.Iovec{
- Base: (*byte)(&fpState[0]),
+ Base: fpState.BytePointer(),
Len: fpLen,
}
_, _, errno := unix.RawSyscall6(
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 5bd526b73..efec93f73 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -75,17 +75,25 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o
// errors, we may consume the error and return only the partial read/write.
//
// Returns false if error is unknown.
-func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error, op string) (bool, error) {
- switch err {
- case nil:
+func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr error, op string) (bool, error) {
+ if errOrig == nil {
// Typical successful syscall.
return true, nil
+ }
+
+ // Translate error, if possible, to consolidate errors from other packages
+ // into a smaller set of errors from syserror package.
+ translatedErr := errOrig
+ if errno, ok := syserror.TranslateError(errOrig); ok {
+ translatedErr = errno
+ }
+ switch translatedErr {
case io.EOF:
// EOF is always consumed. If this is a partial read/write
// (result != 0), the application will see that, otherwise
// they will see 0.
return true, nil
- case syserror.ErrExceedsFileSizeLimit:
+ case syserror.EFBIG:
t := kernel.TaskFromContext(ctx)
if t == nil {
panic("I/O error should only occur from a context associated with a Task")
@@ -98,7 +106,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
// Simultaneously send a SIGXFSZ per setrlimit(2).
t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
return true, syserror.EFBIG
- case syserror.ErrInterrupted:
+ case syserror.EINTR:
// The syscall was interrupted. Return nil if it completed
// partially, otherwise return the error code that the syscall
// needs (to indicate to the kernel what it should do).
@@ -110,10 +118,10 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
if !partialResult {
// Typical syscall error.
- return true, err
+ return true, errOrig
}
- switch err {
+ switch translatedErr {
case syserror.EINTR:
// Syscall interrupted, but completed a partial
// read/write. Like ErrWouldBlock, since we have a
@@ -143,7 +151,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
// For TCP sendfile connections, we may have a reset or timeout. But we
// should just return n as the result.
return true, nil
- case syserror.ErrWouldBlock:
+ case syserror.EWOULDBLOCK:
// Syscall would block, but completed a partial read/write.
// This case should only be returned by IssueIO for nonblocking
// files. Since we have a partial read/write, we consume
@@ -151,7 +159,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, err, intr error,
return true, nil
}
- switch err.(type) {
+ switch errOrig.(type) {
case syserror.SyscallRestartErrno:
// Identical to the EINTR case.
return true, nil
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 97de17afe..56b621357 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -130,17 +130,15 @@ func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) {
// TranslateError translates errors to errnos, it will return false if
// the error was not registered.
func TranslateError(from error) (unix.Errno, bool) {
- err, ok := errorMap[from]
- if ok {
- return err, ok
+ if err, ok := errorMap[from]; ok {
+ return err, true
}
// Try to unwrap the error if we couldn't match an error
// exactly. This might mean that a package has its own
// error type.
for _, unwrap := range errorUnwrappers {
- err, ok := unwrap(from)
- if ok {
- return err, ok
+ if err, ok := unwrap(from); ok {
+ return err, true
}
}
return 0, false
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index fc622b246..fef065b05 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1287,6 +1287,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption
} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
}
+ case header.NDPNonceOption:
+ gotOpt, ok := opt.(header.NDPNonceOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
+ t.Errorf("nonce mismatch (-want +got):\n%s", diff)
+ }
default:
t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 554242f0c..3d1bccd15 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -26,29 +26,33 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// NDPOptionIdentifier is an NDP option type identifier.
-type NDPOptionIdentifier uint8
+// ndpOptionIdentifier is an NDP option type identifier.
+type ndpOptionIdentifier uint8
const (
- // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // ndpSourceLinkLayerAddressOptionType is the type of the Source Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
+ ndpSourceLinkLayerAddressOptionType ndpOptionIdentifier = 1
- // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // ndpTargetLinkLayerAddressOptionType is the type of the Target Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
+ ndpTargetLinkLayerAddressOptionType ndpOptionIdentifier = 2
- // NDPPrefixInformationType is the type of the Prefix Information
+ // ndpPrefixInformationType is the type of the Prefix Information
// option, as per RFC 4861 section 4.6.2.
- NDPPrefixInformationType NDPOptionIdentifier = 3
+ ndpPrefixInformationType ndpOptionIdentifier = 3
- // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // ndpNonceOptionType is the type of the Nonce option, as per
+ // RFC 3971 section 5.3.2.
+ ndpNonceOptionType ndpOptionIdentifier = 14
+
+ // ndpRecursiveDNSServerOptionType is the type of the Recursive DNS
// Server option, as per RFC 8106 section 5.1.
- NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+ ndpRecursiveDNSServerOptionType ndpOptionIdentifier = 25
- // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // ndpDNSSearchListOptionType is the type of the DNS Search List option,
// as per RFC 8106 section 5.2.
- NDPDNSSearchListOptionType = 31
+ ndpDNSSearchListOptionType ndpOptionIdentifier = 31
)
const (
@@ -198,7 +202,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
// bytes for the whole option.
return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
}
- kind := NDPOptionIdentifier(temp)
+ kind := ndpOptionIdentifier(temp)
// Get the Length field.
length, err := i.opts.ReadByte()
@@ -225,13 +229,16 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
}
switch kind {
- case NDPSourceLinkLayerAddressOptionType:
+ case ndpSourceLinkLayerAddressOptionType:
return NDPSourceLinkLayerAddressOption(body), false, nil
- case NDPTargetLinkLayerAddressOptionType:
+ case ndpTargetLinkLayerAddressOptionType:
return NDPTargetLinkLayerAddressOption(body), false, nil
- case NDPPrefixInformationType:
+ case ndpNonceOptionType:
+ return NDPNonceOption(body), false, nil
+
+ case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
// section 4.6.2.
@@ -241,7 +248,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
return NDPPrefixInformation(body), false, nil
- case NDPRecursiveDNSServerOptionType:
+ case ndpRecursiveDNSServerOptionType:
opt := NDPRecursiveDNSServer(body)
if err := opt.checkAddresses(); err != nil {
return nil, true, err
@@ -249,7 +256,7 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
return opt, false, nil
- case NDPDNSSearchListOptionType:
+ case ndpDNSSearchListOptionType:
opt := NDPDNSSearchList(body)
if err := opt.checkDomainNames(); err != nil {
return nil, true, err
@@ -316,7 +323,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
continue
}
- b[0] = byte(o.Type())
+ b[0] = byte(o.kind())
// We know this safe because paddedLength would have returned
// 0 if o had an invalid length (> 255 * lengthByteUnits).
@@ -341,11 +348,11 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
type NDPOption interface {
fmt.Stringer
- // Type returns the type of the receiver.
- Type() NDPOptionIdentifier
+ // kind returns the type of the receiver.
+ kind() ndpOptionIdentifier
- // Length returns the length of the body of the receiver, in bytes.
- Length() int
+ // length returns the length of the body of the receiver, in bytes.
+ length() int
// serializeInto serializes the receiver into the provided byte
// buffer.
@@ -365,7 +372,7 @@ type NDPOption interface {
// paddedLength returns the length of o, in bytes, with any padding bytes, if
// required.
func paddedLength(o NDPOption) int {
- l := o.Length()
+ l := o.length()
if l == 0 {
return 0
@@ -416,6 +423,37 @@ func (b NDPOptionsSerializer) Length() int {
return l
}
+// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
+type NDPNonceOption []byte
+
+// kind implements NDPOption.
+func (o NDPNonceOption) kind() ndpOptionIdentifier {
+ return ndpNonceOptionType
+}
+
+// length implements NDPOption.
+func (o NDPNonceOption) length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.
+func (o NDPNonceOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPNonceOption) String() string {
+ return fmt.Sprintf("%T(%x)", o, []byte(o))
+}
+
+// Nonce returns the nonce value this option holds.
+func (o NDPNonceOption) Nonce() []byte {
+ return []byte(o)
+}
+
// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
// as defined by RFC 4861 section 4.6.1.
//
@@ -423,22 +461,22 @@ func (b NDPOptionsSerializer) Length() int {
// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
-// Type implements NDPOption.Type.
-func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
- return NDPSourceLinkLayerAddressOptionType
+// kind implements NDPOption.
+func (o NDPSourceLinkLayerAddressOption) kind() ndpOptionIdentifier {
+ return ndpSourceLinkLayerAddressOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPSourceLinkLayerAddressOption) Length() int {
+// length implements NDPOption.
+func (o NDPSourceLinkLayerAddressOption) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPSourceLinkLayerAddressOption) String() string {
return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
}
@@ -463,22 +501,22 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
-// Type implements NDPOption.Type.
-func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
- return NDPTargetLinkLayerAddressOptionType
+// kind implements NDPOption.
+func (o NDPTargetLinkLayerAddressOption) kind() ndpOptionIdentifier {
+ return ndpTargetLinkLayerAddressOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPTargetLinkLayerAddressOption) Length() int {
+// length implements NDPOption.
+func (o NDPTargetLinkLayerAddressOption) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPTargetLinkLayerAddressOption) String() string {
return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
}
@@ -503,17 +541,17 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
// ndpPrefixInformationLength bytes.
type NDPPrefixInformation []byte
-// Type implements NDPOption.Type.
-func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
- return NDPPrefixInformationType
+// kind implements NDPOption.
+func (o NDPPrefixInformation) kind() ndpOptionIdentifier {
+ return ndpPrefixInformationType
}
-// Length implements NDPOption.Length.
-func (o NDPPrefixInformation) Length() int {
+// length implements NDPOption.
+func (o NDPPrefixInformation) length() int {
return ndpPrefixInformationLength
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPPrefixInformation) serializeInto(b []byte) int {
used := copy(b, o)
@@ -529,7 +567,7 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPPrefixInformation) String() string {
return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
o,
@@ -627,17 +665,17 @@ type NDPRecursiveDNSServer []byte
// Type returns the type of an NDP Recursive DNS Server option.
//
-// Type implements NDPOption.Type.
-func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
- return NDPRecursiveDNSServerOptionType
+// kind implements NDPOption.
+func (NDPRecursiveDNSServer) kind() ndpOptionIdentifier {
+ return ndpRecursiveDNSServerOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPRecursiveDNSServer) Length() int {
+// length implements NDPOption.
+func (o NDPRecursiveDNSServer) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
used := copy(b, o)
@@ -649,7 +687,7 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPRecursiveDNSServer) String() string {
lt := o.Lifetime()
addrs, err := o.Addresses()
@@ -722,17 +760,17 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
// RFC 8106 section 5.2.
type NDPDNSSearchList []byte
-// Type implements NDPOption.Type.
-func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
- return NDPDNSSearchListOptionType
+// kind implements NDPOption.
+func (o NDPDNSSearchList) kind() ndpOptionIdentifier {
+ return ndpDNSSearchListOptionType
}
-// Length implements NDPOption.Length.
-func (o NDPDNSSearchList) Length() int {
+// length implements NDPOption.
+func (o NDPDNSSearchList) length() int {
return len(o)
}
-// serializeInto implements NDPOption.serializeInto.
+// serializeInto implements NDPOption.
func (o NDPDNSSearchList) serializeInto(b []byte) int {
used := copy(b, o)
@@ -744,7 +782,7 @@ func (o NDPDNSSearchList) serializeInto(b []byte) int {
return used
}
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (o NDPDNSSearchList) String() string {
lt := o.Lifetime()
domainNames, err := o.DomainNames()
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index dc4591253..d0a1a2492 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -16,6 +16,7 @@ package header
import (
"bytes"
+ "encoding/binary"
"errors"
"fmt"
"io"
@@ -192,90 +193,6 @@ func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
}
-// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
-// NDPSourceLinkLayerAddressOption.
-func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expectedBuf []byte
- addr tcpip.LinkAddress
- }{
- {
- "Ethernet",
- make([]byte, 8),
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
- },
- {
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
- },
- {
- "Empty",
- nil,
- nil,
- "",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- serializer := NDPOptionsSerializer{
- NDPSourceLinkLayerAddressOption(test.addr),
- }
- if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
- }
- opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- if len(test.expectedBuf) > 0 {
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
- }
- sll := next.(NDPSourceLinkLayerAddressOption)
- if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
- }
- }
-
- // Iterator should not return anything else.
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
// Ethernet address from an NDPTargetLinkLayerAddressOption.
func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
@@ -311,32 +228,309 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
}
-// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
-// NDPTargetLinkLayerAddressOption.
-func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
+func TestOpts(t *testing.T) {
+ const optionHeaderLen = 2
+
+ checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpNonceOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType)
+ }
+ nonce, ok := opt.(NDPNonceOption)
+ if !ok {
+ t.Fatalf("got nonce = %T, want = NDPNonceOption", opt)
+ }
+ if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" {
+ t.Errorf("nonce mismatch (-want +got):\n%s", diff)
+ }
+ }
+ }
+
+ checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
+ }
+ tll, ok := opt.(NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt)
+ }
+ if got, want := tll.EthernetAddress(), expectedAddr; got != want {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+ }
+
+ checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
+ return func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
+ }
+ sll, ok := opt.(NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt)
+ }
+ if got, want := sll.EthernetAddress(), expectedAddr; got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+ }
+
+ const validLifetimeSeconds = 16909060
+ const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18")
+
+ expectedRDNSSBytes := [...]byte{
+ // Type, Length
+ 25, 3,
+
+ // Reserved
+ 0, 0,
+
+ // Lifetime
+ 1, 2, 4, 8,
+
+ // Address
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds)
+ if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize {
+ t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
+ }
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ rdnssBytes := expectedRDNSSBytes
+ rdnssBytes[1] = 1
+ rdnssBytes[2] = 2
+
+ const searchListPaddingBytes = 3
+ const domainName = "abc.abcd.e"
+ expectedSearchListBytes := [...]byte{
+ // Type, Length
+ 31, 3,
+
+ // Reserved
+ 0, 0,
+
+ // Lifetime
+ 1, 0, 0, 0,
+
+ // Domain names
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds)
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ searchListBytes := expectedSearchListBytes
+ searchListBytes[2] = 1
+ searchListBytes[3] = 2
+
+ const prefixLength = 43
+ const onLinkFlag = false
+ const slaacFlag = true
+ const preferredLifetimeSeconds = 84281096
+ const onLinkFlagBit = 7
+ const slaacFlagBit = 6
+ boolToByte := func(v bool) byte {
+ if v {
+ return 1
+ }
+ return 0
+ }
+ flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit
+ expectedPrefixInformationBytes := [...]byte{
+ // Type, Length
+ 3, 4,
+
+ prefixLength, flags,
+
+ // Valid Lifetime
+ 1, 2, 3, 4,
+
+ // Preferred Lifetime
+ 5, 6, 7, 8,
+
+ // Reserved2
+ 0, 0, 0, 0,
+
+ // Address
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+ binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds)
+ binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds)
+ if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize {
+ t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
+ }
+ // Update reserved fields to non zero values to make sure serializing sets
+ // them to zero.
+ prefixInformationBytes := expectedPrefixInformationBytes
+ prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1
+ binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1)
tests := []struct {
name string
buf []byte
+ opt NDPOption
expectedBuf []byte
- addr tcpip.LinkAddress
+ check func(*testing.T, NDPOption)
}{
{
- "Ethernet",
- make([]byte, 8),
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
+ name: "Nonce",
+ buf: make([]byte, 8),
+ opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}),
+ expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6},
+ check: checkNonce([]byte{1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Nonce with padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}),
+ expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0},
+ check: checkNonce([]byte{1, 2, 3, 4, 5, 0}),
+ },
+
+ {
+ name: "TLL Ethernet",
+ buf: make([]byte, 8),
+ opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
+ expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ check: checkTLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "TLL Padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
+ expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ check: checkTLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "TLL Empty",
+ buf: nil,
+ opt: NDPTargetLinkLayerAddressOption(""),
+ expectedBuf: nil,
},
+
{
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
+ name: "SLL Ethernet",
+ buf: make([]byte, 8),
+ opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
+ expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ check: checkSLL("\x01\x02\x03\x04\x05\x06"),
},
{
- "Empty",
- nil,
- nil,
- "",
+ name: "SLL Padding",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
+ expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ check: checkSLL("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ name: "SLL Empty",
+ buf: nil,
+ opt: NDPSourceLinkLayerAddressOption(""),
+ expectedBuf: nil,
+ },
+
+ {
+ name: "RDNSS",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ // NDPRecursiveDNSServer holds the option after the header bytes.
+ opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]),
+ expectedBuf: expectedRDNSSBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpRecursiveDNSServerOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
+ }
+ rdnss, ok := opt.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt)
+ }
+ if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got Lifetime() = %s, want = %s", got, want)
+ }
+ if addrs, err := rdnss.Addresses(); err != nil {
+ t.Errorf("Addresses(): %s", err)
+ } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+ },
+ },
+
+ {
+ name: "Search list",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]),
+ expectedBuf: expectedSearchListBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpDNSSearchListOptionType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType)
+ }
+
+ dnssl, ok := opt.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt)
+ }
+ if got, want := dnssl.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got Lifetime() = %s, want = %s", got, want)
+ }
+
+ if domainNames, err := dnssl.DomainNames(); err != nil {
+ t.Errorf("DomainNames(): %s", err)
+ } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+ },
+ },
+
+ {
+ name: "Prefix Information",
+ buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ // NDPPrefixInformation holds the option after the header bytes.
+ opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]),
+ expectedBuf: expectedPrefixInformationBytes[:],
+ check: func(t *testing.T, opt NDPOption) {
+ if got := opt.kind(); got != ndpPrefixInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType)
+ }
+
+ pi, ok := opt.(NDPPrefixInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt)
+ }
+
+ if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want {
+ t.Errorf("got length() = %d, want = %d", got, want)
+ }
+ if got := pi.PrefixLength(); got != prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength)
+ }
+ if got := pi.OnLinkFlag(); got != onLinkFlag {
+ t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag)
+ }
+ if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag {
+ t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag)
+ }
+ if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got ValidLifetime() = %s, want = %s", got, want)
+ }
+ if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want {
+ t.Errorf("got PreferredLifetime() = %s, want = %s", got, want)
+ }
+ if got := pi.Prefix(); got != address {
+ t.Errorf("got Prefix() = %s, want = %s", got, address)
+ }
+ },
},
}
@@ -344,230 +538,47 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := NDPOptions(test.buf)
serializer := NDPOptionsSerializer{
- NDPTargetLinkLayerAddressOption(test.addr),
+ test.opt,
}
if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
+ t.Fatalf("got Length() = %d, want = %d", got, want)
}
opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" {
+ t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff)
}
it, err := opts.Iter(true)
if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err)
}
if len(test.expectedBuf) > 0 {
next, done, err := it.Next()
if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err)
}
if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
- }
- tll := next.(NDPTargetLinkLayerAddressOption)
- if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ t.Fatal("got Next() = (_, true, _), want = (_, false, _)")
}
+ test.check(t, next)
}
// Iterator should not return anything else.
next, done, err := it.Next()
if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
}
if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
}
if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
}
})
}
}
-// TestNDPPrefixInformationOption tests the field getters and serialization of a
-// NDPPrefixInformation.
-func TestNDPPrefixInformationOption(t *testing.T) {
- b := []byte{
- 43, 127,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 5, 5, 5, 5,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
-
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPPrefixInformation(b),
- }
- opts.Serialize(serializer)
- expectedBuf := []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
- if !bytes.Equal(targetBuf, expectedBuf) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
- }
-
- pi := next.(NDPPrefixInformation)
-
- if got := pi.Type(); got != 3 {
- t.Errorf("got Type = %d, want = 3", got)
- }
-
- if got := pi.Length(); got != 30 {
- t.Errorf("got Length = %d, want = 30", got)
- }
-
- if got := pi.PrefixLength(); got != 43 {
- t.Errorf("got PrefixLength = %d, want = 43", got)
- }
-
- if pi.OnLinkFlag() {
- t.Error("got OnLinkFlag = true, want = false")
- }
-
- if !pi.AutonomousAddressConfigurationFlag() {
- t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
- }
-
- if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
- t.Errorf("got ValidLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
- t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
- t.Errorf("got Prefix = %s, want = %s", got, want)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
-func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
- b := []byte{
- 9, 8,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- expected := []byte{
- 25, 3, 0, 0,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPRecursiveDNSServer(b),
- }
- if got, want := opts.Serialize(serializer), len(expected); got != want {
- t.Errorf("got Serialize = %d, want = %d", got, want)
- }
- if !bytes.Equal(targetBuf, expected) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
- }
-
- opt, ok := next.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
- }
- if got := opt.Type(); got != 25 {
- t.Errorf("got Type = %d, want = 31", got)
- }
- if got := opt.Length(); got != 22 {
- t.Errorf("got Length = %d, want = 22", got)
- }
- if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
- t.Errorf("got Lifetime = %s, want = %s", got, want)
- }
- want := []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- }
- addrs, err := opt.Addresses()
- if err != nil {
- t.Errorf("opt.Addresses() = %s", err)
- }
- if diff := cmp.Diff(addrs, want); diff != "" {
- t.Errorf("mismatched addresses (-want +got):\n%s", diff)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
func TestNDPRecursiveDNSServerOption(t *testing.T) {
tests := []struct {
name string
@@ -635,8 +646,8 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ if got := next.kind(); got != ndpRecursiveDNSServerOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
}
opt, ok := next.(NDPRecursiveDNSServer)
@@ -1060,86 +1071,6 @@ func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
}
}
-func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
- b := []byte{
- 9, 8,
- 1, 0, 0, 0,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- }
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- expected := []byte{
- 31, 3, 0, 0,
- 1, 0, 0, 0,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- 0, 0, 0, 0,
- }
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPDNSSearchList(b),
- }
- if got, want := opts.Serialize(serializer), len(expected); got != want {
- t.Errorf("got Serialize = %d, want = %d", got, want)
- }
- if !bytes.Equal(targetBuf, expected) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPDNSSearchListOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
- }
-
- opt, ok := next.(NDPDNSSearchList)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
- }
- if got := opt.Type(); got != 31 {
- t.Errorf("got Type = %d, want = 31", got)
- }
- if got := opt.Length(); got != 22 {
- t.Errorf("got Length = %d, want = 22", got)
- }
- if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
- t.Errorf("got Lifetime = %s, want = %s", got, want)
- }
- domainNames, err := opt.DomainNames()
- if err != nil {
- t.Errorf("opt.DomainNames() = %s", err)
- }
- if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" {
- t.Errorf("domain names mismatch (-want +got):\n%s", diff)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
// the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
@@ -1472,8 +1403,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
}
// Test the next (Target Link-Layer) option.
@@ -1487,8 +1418,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
}
// Test the next (Prefix Information) option.
@@ -1503,8 +1434,8 @@ func TestNDPOptionsIter(t *testing.T) {
if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ if got := next.kind(); got != ndpPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, ndpPrefixInformationType)
}
// Iterator should not return anything else.
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
index 6fe9a336b..55ab1d7cf 100644
--- a/pkg/tcpip/header/ndpoptionidentifier_string.go
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT.
package header
@@ -22,29 +22,37 @@ func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
- _ = x[NDPSourceLinkLayerAddressOptionType-1]
- _ = x[NDPTargetLinkLayerAddressOptionType-2]
- _ = x[NDPPrefixInformationType-3]
- _ = x[NDPRecursiveDNSServerOptionType-25]
+ _ = x[ndpSourceLinkLayerAddressOptionType-1]
+ _ = x[ndpTargetLinkLayerAddressOptionType-2]
+ _ = x[ndpPrefixInformationType-3]
+ _ = x[ndpNonceOptionType-14]
+ _ = x[ndpRecursiveDNSServerOptionType-25]
+ _ = x[ndpDNSSearchListOptionType-31]
}
const (
- _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
- _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+ _ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType"
+ _ndpOptionIdentifier_name_1 = "ndpNonceOptionType"
+ _ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType"
+ _ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType"
)
var (
- _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+ _ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
)
-func (i NDPOptionIdentifier) String() string {
+func (i ndpOptionIdentifier) String() string {
switch {
case 1 <= i && i <= 3:
i -= 1
- return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]]
+ case i == 14:
+ return _ndpOptionIdentifier_name_1
case i == 25:
- return _NDPOptionIdentifier_name_1
+ return _ndpOptionIdentifier_name_2
+ case i == 31:
+ return _ndpOptionIdentifier_name_3
default:
- return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ae0461a6d..7ae38d684 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -38,6 +38,7 @@ const (
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
+var _ ip.DADProtocol = (*endpoint)(nil)
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
@@ -82,7 +83,8 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error {
+// SendDADMessage implements ip.DADProtocol.
+func (e *endpoint) SendDADMessage(addr tcpip.Address, _ []byte) tcpip.Error {
return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress)
}
@@ -284,9 +286,12 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
e.mu.Lock()
e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{
- Clock: p.stack.Clock(),
- Protocol: e,
- NICID: nic.ID(),
+ Clock: p.stack.Clock(),
+ SecureRNG: p.stack.SecureRNG(),
+ // ARP does not support sending nonce values.
+ NonceSize: 0,
+ Protocol: e,
+ NICID: nic.ID(),
})
e.mu.Unlock()
@@ -305,8 +310,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
- nicID := e.nic.ID()
-
stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
@@ -314,9 +317,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
- addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
- if !ok {
- return &tcpip.ErrUnknownNICID{}
+ addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
}
if len(addr.Address) == 0 {
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index 0053646ee..eed49f5d2 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -16,14 +16,27 @@
package ip
import (
+ "bytes"
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+type extendRequest int
+
+const (
+ notRequested extendRequest = iota
+ requested
+ extended
+)
+
type dadState struct {
+ nonce []byte
+ extendRequest extendRequest
+
done *bool
timer tcpip.Timer
@@ -33,14 +46,17 @@ type dadState struct {
// DADProtocol is a protocol whose core state machine can be represented by DAD.
type DADProtocol interface {
// SendDADMessage attempts to send a DAD probe message.
- SendDADMessage(tcpip.Address) tcpip.Error
+ SendDADMessage(tcpip.Address, []byte) tcpip.Error
}
// DADOptions holds options for DAD.
type DADOptions struct {
- Clock tcpip.Clock
- Protocol DADProtocol
- NICID tcpip.NICID
+ Clock tcpip.Clock
+ SecureRNG io.Reader
+ NonceSize uint8
+ ExtendDADTransmits uint8
+ Protocol DADProtocol
+ NICID tcpip.NICID
}
// DAD performs duplicate address detection for addresses.
@@ -63,6 +79,10 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic("attempted to initialize DAD state twice")
}
+ if opts.NonceSize != 0 && opts.ExtendDADTransmits == 0 {
+ panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
+ }
+
*d = DAD{
opts: opts,
configs: configs,
@@ -96,10 +116,55 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
s = dadState{
done: &done,
timer: d.opts.Clock.AfterFunc(0, func() {
- var err tcpip.Error
dadDone := remaining == 0
+
+ nonce, earlyReturn := func() ([]byte, bool) {
+ d.protocolMU.Lock()
+ defer d.protocolMU.Unlock()
+
+ if done {
+ return nil, true
+ }
+
+ s, ok := d.addresses[addr]
+ if !ok {
+ panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID))
+ }
+
+ // As per RFC 7527 section 4
+ //
+ // If any probe is looped back within RetransTimer milliseconds
+ // after having sent DupAddrDetectTransmits NS(DAD) messages, the
+ // interface continues with another MAX_MULTICAST_SOLICIT number of
+ // NS(DAD) messages transmitted RetransTimer milliseconds apart.
+ if dadDone && s.extendRequest == requested {
+ dadDone = false
+ remaining = d.opts.ExtendDADTransmits
+ s.extendRequest = extended
+ }
+
+ if !dadDone && d.opts.NonceSize != 0 {
+ if s.nonce == nil {
+ s.nonce = make([]byte, d.opts.NonceSize)
+ }
+
+ if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil {
+ panic(fmt.Sprintf("SecureRNG.Read(...): %s", err))
+ } else if n != len(s.nonce) {
+ panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n))
+ }
+ }
+
+ d.addresses[addr] = s
+ return s.nonce, false
+ }()
+ if earlyReturn {
+ return
+ }
+
+ var err tcpip.Error
if !dadDone {
- err = d.opts.Protocol.SendDADMessage(addr)
+ err = d.opts.Protocol.SendDADMessage(addr, nonce)
}
d.protocolMU.Lock()
@@ -142,6 +207,68 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
return ret
}
+// ExtendIfNonceEqualLockedDisposition enumerates the possible results from
+// ExtendIfNonceEqualLocked.
+type ExtendIfNonceEqualLockedDisposition int
+
+const (
+ // Extended indicates that the DAD process was extended.
+ Extended ExtendIfNonceEqualLockedDisposition = iota
+
+ // AlreadyExtended indicates that the DAD process was already extended.
+ AlreadyExtended
+
+ // NoDADStateFound indicates that DAD state was not found for the address.
+ NoDADStateFound
+
+ // NonceDisabled indicates that nonce values are not sent with DAD messages.
+ NonceDisabled
+
+ // NonceNotEqual indicates that the nonce value passed and the nonce in the
+ // last send DAD message are not equal.
+ NonceNotEqual
+)
+
+// ExtendIfNonceEqualLocked extends the DAD process if the provided nonce is the
+// same as the nonce sent in the last DAD message.
+//
+// Precondition: d.protocolMU must be locked.
+func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition {
+ s, ok := d.addresses[addr]
+ if !ok {
+ return NoDADStateFound
+ }
+
+ if d.opts.NonceSize == 0 {
+ return NonceDisabled
+ }
+
+ if s.extendRequest != notRequested {
+ return AlreadyExtended
+ }
+
+ // As per RFC 7527 section 4
+ //
+ // If any probe is looped back within RetransTimer milliseconds after having
+ // sent DupAddrDetectTransmits NS(DAD) messages, the interface continues
+ // with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted
+ // RetransTimer milliseconds apart.
+ //
+ // If a DAD message has already been sent and the nonce value we observed is
+ // the same as the nonce value we last sent, then we assume our probe was
+ // looped back and request an extension to the DAD process.
+ //
+ // Note, the first DAD message is sent asynchronously so we need to make sure
+ // that we sent a DAD message by checking if we have a nonce value set.
+ if s.nonce != nil && bytes.Equal(s.nonce, nonce) {
+ s.extendRequest = requested
+ d.addresses[addr] = s
+ return Extended
+ }
+
+ return NonceNotEqual
+}
+
// StopLocked stops a currently running DAD process.
//
// Precondition: d.protocolMU must be locked.
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index e00aa4678..a22b712c6 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"testing"
"time"
@@ -32,8 +33,8 @@ type mockDADProtocol struct {
mu struct {
sync.Mutex
- dad ip.DAD
- sendCount map[tcpip.Address]int
+ dad ip.DAD
+ sentNonces map[tcpip.Address][][]byte
}
}
@@ -48,26 +49,30 @@ func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.
}
func (m *mockDADProtocol) initLocked() {
- m.mu.sendCount = make(map[tcpip.Address]int)
+ m.mu.sentNonces = make(map[tcpip.Address][][]byte)
}
-func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error {
+func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
m.mu.Lock()
defer m.mu.Unlock()
- m.mu.sendCount[addr]++
+ m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce)
return nil
}
func (m *mockDADProtocol) check(addrs []tcpip.Address) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- sendCount := make(map[tcpip.Address]int)
+ sentNonces := make(map[tcpip.Address][][]byte)
for _, a := range addrs {
- sendCount[a]++
+ sentNonces[a] = append(sentNonces[a], nil)
}
- diff := cmp.Diff(sendCount, m.mu.sendCount)
+ return m.checkWithNonce(sentNonces)
+}
+
+func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces)
m.initLocked()
return diff
}
@@ -84,6 +89,12 @@ func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
m.mu.dad.StopLocked(addr, reason)
}
+func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce)
+}
+
func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -277,3 +288,94 @@ func TestDADStop(t *testing.T) {
default:
}
}
+
+func TestNonce(t *testing.T) {
+ const (
+ nonceSize = 2
+
+ extendRequestAttempts = 2
+
+ dupAddrDetectTransmits = 2
+ extendTransmits = 5
+ )
+
+ var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte
+ for i := range secureRNGBytes {
+ secureRNGBytes[i] = byte(i)
+ }
+
+ tests := []struct {
+ name string
+ mockedReceivedNonce []byte
+ expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition
+ expectedTransmits int
+ }{
+ {
+ name: "not matching",
+ mockedReceivedNonce: []byte{0, 0},
+ expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual},
+ expectedTransmits: dupAddrDetectTransmits,
+ },
+ {
+ name: "matching nonce",
+ mockedReceivedNonce: secureRNGBytes[:nonceSize],
+ expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended},
+ expectedTransmits: dupAddrDetectTransmits + extendTransmits,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var dad mockDADProtocol
+ clock := faketime.NewManualClock()
+ dadConfigs := stack.DADConfigurations{
+ DupAddrDetectTransmits: dupAddrDetectTransmits,
+ RetransmitTimer: time.Second,
+ }
+
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
+ dad.init(t, dadConfigs, ip.DADOptions{
+ Clock: clock,
+ SecureRNG: &secureRNG,
+ NonceSize: nonceSize,
+ ExtendDADTransmits: extendTransmits,
+ })
+
+ ch := make(chan dadResult, 1)
+ if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
+ }
+
+ clock.Advance(0)
+ for i, want := range test.expectedResults {
+ if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
+ t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
+ }
+ }
+
+ for i := 0; i < test.expectedTransmits; i++ {
+ if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{
+ addr1: {
+ secureRNGBytes[nonceSize*i:][:nonceSize],
+ },
+ }); diff != "" {
+ t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff)
+ }
+
+ clock.Advance(dadConfigs.RetransmitTimer)
+ }
+
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
+ t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anymore updates.
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index aee1652fa..a4edc69c7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 8a2140ebe..a1660e9a3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if !e.protocol.options.AllowExternalLoopbackTraffic {
+ if header.IsV4LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV4LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
- h := header.IPv4(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
- // There has been some confusion regarding verifying checksums. We need
- // just look for negative 0 (0xffff) as the checksum, as it's not possible to
- // get positive 0 (0) for the checksum. Some bad implementations could get it
- // when doing entry replacement in the early days of the Internet,
- // however the lore that one needs to check for both persists.
- //
- // RFC 1624 section 1 describes the source of this confusion as:
- // [the partial recalculation method described in RFC 1071] computes a
- // result for certain cases that differs from the one obtained from
- // scratch (one's complement of one's complement sum of the original
- // fields).
- //
- // However RFC 1624 section 5 clarifies that if using the verification method
- // "recommended by RFC 1071, it does not matter if an intermediate system
- // generated a -0 instead of +0".
- //
- // RFC1071 page 1 specifies the verification method as:
- // (3) To check a checksum, the 1's complement sum is computed over the
- // same set of octets, including the checksum field. If the result
- // is all 1 bits (-0 in 1's complement arithmetic), the check
- // succeeds.
- if h.CalculateChecksum() != 0xffff {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1114,13 +1097,46 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns true if the IP header was successfully parsed.
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv4(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
+ }
+
+ // There has been some confusion regarding verifying checksums. We need
+ // just look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
+ return nil, false
}
if hasTransportHdr {
@@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN
type Options struct {
// IGMP holds options for IGMP.
IGMP IGMPOptions
+
+ // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be accepted.
+ AllowExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv4 network protocol.
@@ -1599,9 +1619,8 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
// TODO(https://gvisor.dev/issue/4586): This will need tweaking when we start
// really forwarding packets as we may need to get two addresses, for rx and
// tx interfaces. We will also have to take usage into account.
- prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
- localAddress := prefixedAddress.Address
- if !ok {
+ localAddress := e.MainAddress().Address
+ if len(localAddress) == 0 {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6344a3e09..2afa856dc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -369,6 +369,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
+ var it header.NDPOptionIterator
+ {
+ var err error
+ it, err = ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the
+ // packet.
+ received.invalid.Increment()
+ return
+ }
+ }
+
if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
@@ -382,6 +394,22 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if srcAddr == header.IPv6Any {
+ var nonce []byte
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ received.invalid.Increment()
+ return
+ }
+ if done {
+ break
+ }
+ if n, ok := opt.(header.NDPNonceOption); ok {
+ nonce = n.Nonce()
+ break
+ }
+ }
+
// Since this is a DAD message we know the sender does not actually hold
// the target address so there is no "holder".
var holderLinkAddress tcpip.LinkAddress
@@ -397,7 +425,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress, nonce); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
@@ -418,21 +446,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- var sourceLinkAddr tcpip.LinkAddress
- {
- it, err := ns.Options().Iter(false /* check */)
- if err != nil {
- // Options are not valid as per the wire format, silently drop the
- // packet.
- received.invalid.Increment()
- return
- }
-
- sourceLinkAddr, ok = getSourceLinkAddr(it)
- if !ok {
- received.invalid.Increment()
- return
- }
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.invalid.Increment()
+ return
}
// As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
@@ -586,6 +603,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
e.dad.mu.Unlock()
if e.hasTentativeAddr(targetAddr) {
+ // We only send a nonce value in DAD messages to check for loopedback
+ // messages so we use the empty nonce value here.
+ var nonce []byte
+
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
@@ -602,7 +623,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr, nonce); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
return
default:
@@ -899,13 +920,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
+ // Find an address that we can use as our source address.
addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
if addressEndpoint == nil {
return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addressEndpoint.AddressWithPrefix().Address
- } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 {
+ addressEndpoint.DecRef()
+ } else if !e.checkLocalAddress(localAddr) {
+ // The provided local address is not assigned to us.
return &tcpip.ErrBadLocalAddress{}
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d4e63710c..47d713f88 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 46b6cc41a..83e98bab9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress, nonce []byte) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -361,27 +361,48 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
return &tcpip.ErrInvalidEndpointState{}
}
- // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
- // attempt will be made to generate a new address for it.
- if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
- return err
- }
+ switch result := e.mu.ndp.dad.ExtendIfNonceEqualLocked(addr, nonce); result {
+ case ip.Extended:
+ // The nonce we got back was the same we sent so we know the message
+ // indicating a duplicate address was likely ours so do not consider
+ // the address duplicate here.
+ return nil
+ case ip.AlreadyExtended:
+ // See Extended.
+ //
+ // Our DAD message was looped back already.
+ return nil
+ case ip.NoDADStateFound:
+ panic(fmt.Sprintf("expected DAD state for tentative address %s", addr))
+ case ip.NonceDisabled:
+ // If nonce is disabled then we have no way to know if the packet was
+ // looped-back so we have to assume it indicates a duplicate address.
+ fallthrough
+ case ip.NonceNotEqual:
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
+ return err
+ }
- prefix := addressEndpoint.Subnet()
+ prefix := addressEndpoint.Subnet()
- switch t := addressEndpoint.ConfigType(); t {
- case stack.AddressConfigStatic:
- case stack.AddressConfigSlaac:
- e.mu.ndp.regenerateSLAACAddr(prefix)
- case stack.AddressConfigSlaacTemp:
- // Do not reset the generation attempts counter for the prefix as the
- // temporary address is being regenerated in response to a DAD conflict.
- e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
default:
- panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ panic(fmt.Sprintf("unhandled result = %d", result))
}
-
- return nil
}
// transitionForwarding transitions the endpoint's forwarding status to
@@ -863,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
-
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -904,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if !e.protocol.options.AllowExternalLoopbackTraffic {
+ if header.IsV6LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV6LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -932,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
-
- h := header.IPv6(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.MalformedPacketsReceived.Increment()
- return
- }
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1797,16 +1826,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
dispatcher: dispatcher,
protocol: p,
}
+
+ // NDP options must be 8 octet aligned and the first 2 bytes are used for
+ // the type and length fields leaving 6 octets as the minimum size for a
+ // nonce option without padding.
+ const nonceSize = 6
+
+ // As per RFC 7527 section 4.1,
+ //
+ // If any probe is looped back within RetransTimer milliseconds after
+ // having sent DupAddrDetectTransmits NS(DAD) messages, the interface
+ // continues with another MAX_MULTICAST_SOLICIT number of NS(DAD)
+ // messages transmitted RetransTimer milliseconds apart.
+ //
+ // Value taken from RFC 4861 section 10.
+ const maxMulticastSolicit = 3
+ dadOptions := ip.DADOptions{
+ Clock: p.stack.Clock(),
+ SecureRNG: p.stack.SecureRNG(),
+ NonceSize: nonceSize,
+ ExtendDADTransmits: maxMulticastSolicit,
+ Protocol: &e.mu.ndp,
+ NICID: nic.ID(),
+ }
+
e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp.init(e)
+ e.mu.ndp.init(e, dadOptions)
e.mu.mld.init(e)
e.dad.mu.Lock()
- e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{
- Clock: p.stack.Clock(),
- Protocol: &e.mu.ndp,
- NICID: nic.ID(),
- })
+ e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, dadOptions)
e.dad.mu.Unlock()
e.mu.Unlock()
@@ -1879,13 +1928,21 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns true if the IP header was successfully parsed.
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv6(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
}
if hasTransportHdr {
@@ -1899,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -2013,6 +2070,10 @@ type Options struct {
// DADConfigs holds the default DAD configurations used by IPv6 endpoints.
DADConfigs stack.DADConfigurations
+
+ // AllowExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be accepted.
+ AllowExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv6 network protocol.
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 266a53e3b..81f5f23c3 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -343,6 +343,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// TestAddIpv6Address tests adding IPv6 addresses.
func TestAddIpv6Address(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
addr tcpip.Address
@@ -367,18 +369,18 @@ func TestAddIpv6Address(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
}
- if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok {
- t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber)
+ if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err)
} else if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 9a425e50a..85a8f9944 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -15,6 +15,7 @@
package ipv6_test
import (
+ "bytes"
"testing"
"time"
@@ -119,11 +120,26 @@ func TestSendQueuedMLDReports(t *testing.T) {
},
}
+ nonce := [...]byte{
+ 1, 2, 3, 4, 5, 6,
+ }
+
+ const maxNSMessages = 2
+ secureRNGBytes := make([]byte, len(nonce)*maxNSMessages)
+ for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] {
+ if n := copy(b, nonce[:]); n != len(nonce) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce))
+ }
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
clock := faketime.NewManualClock()
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
@@ -154,7 +170,7 @@ func TestSendQueuedMLDReports(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr),
- checker.NDPNSOptions(nil),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}),
))
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d9b728878..536493f87 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1789,18 +1789,14 @@ func (ndp *ndpState) stopSolicitingRouters() {
ndp.rtrSolicitTimer = timer{}
}
-func (ndp *ndpState) init(ep *endpoint) {
+func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
if ndp.defaultRouters != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
- ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{
- Clock: ep.protocol.stack.Clock(),
- Protocol: ndp,
- NICID: ep.nic.ID(),
- })
+ ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
@@ -1811,9 +1807,11 @@ func (ndp *ndpState) init(ep *endpoint) {
}
}
-func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error {
+func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */)
+ return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), header.NDPOptionsSerializer{
+ header.NDPNonceOption(nonce),
+ })
}
func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 6e850fd46..52b9a200c 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "bytes"
"context"
"strings"
"testing"
@@ -1264,8 +1265,21 @@ func TestCheckDuplicateAddress(t *testing.T) {
DupAddrDetectTransmits: 1,
RetransmitTimer: time.Second,
}
+
+ nonces := [...][]byte{
+ {1, 2, 3, 4, 5, 6},
+ {7, 8, 9, 10, 11, 12},
+ }
+
+ var secureRNGBytes []byte
+ for _, n := range nonces {
+ secureRNGBytes = append(secureRNGBytes, n...)
+ }
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- Clock: clock,
+ SecureRNG: &secureRNG,
+ Clock: clock,
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
DADConfigs: dadConfigs,
})},
@@ -1278,10 +1292,36 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- dadPacketsSent := 1
+ dadPacketsSent := 0
+ snmc := header.SolicitedNodeAddr(lladdr0)
+ remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
+ checkDADMsg := func() {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("expected %d-th DAD message", dadPacketsSent)
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ if p.Route.RemoteLinkAddress != remoteLinkAddr {
+ t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
+ ))
+ }
if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
+ checkDADMsg()
// Start DAD for the address we just added.
//
@@ -1297,6 +1337,7 @@ func TestCheckDuplicateAddress(t *testing.T) {
} else if res != stack.DADStarting {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting)
}
+ checkDADMsg()
// Remove the address and make sure our DAD request was not stopped.
if err := s.RemoveAddress(nicID, lladdr0); err != nil {
@@ -1328,33 +1369,6 @@ func TestCheckDuplicateAddress(t *testing.T) {
default:
}
- snmc := header.SolicitedNodeAddr(lladdr0)
- remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
-
- for i := 0; i < dadPacketsSent; i++ {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected %d-th DAD message", i)
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber)
- }
-
- if p.Route.RemoteLinkAddress != remoteLinkAddr {
- t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions(nil),
- ))
- }
-
// Should have no more packets.
if p, ok := e.Read(); ok {
t.Errorf("got unexpected packet = %#v", p)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 47796a6ba..43e6d102c 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"context"
"encoding/binary"
"fmt"
@@ -29,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
@@ -358,6 +360,66 @@ func TestDADDisabled(t *testing.T) {
}
}
+func TestDADResolveLoopback(t *testing.T) {
+ const nicID = 1
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+
+ dadConfigs := stack.DADConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 1,
+ }
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ DADConfigs: dadConfigs,
+ })},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
+ }
+
+ // DAD should not resolve after the normal resolution time since our DAD
+ // message was looped back - we should extend our DAD process.
+ dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
+ clock.Advance(dadResolutionTime)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
+ }
+
+ // Make sure the address does not resolve before the extended resolution time
+ // has passed.
+ const delta = time.Nanosecond
+ // DAD will send extra NS probes if an NS message is looped back.
+ const extraTransmits = 3
+ clock.Advance(dadResolutionTime*extraTransmits - delta)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
+ }
+
+ // DAD should now resolve.
+ clock.Advance(delta)
+ if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -404,6 +466,16 @@ func TestDADResolve(t *testing.T) {
},
}
+ nonces := [][]byte{
+ {1, 2, 3, 4, 5, 6},
+ {7, 8, 9, 10, 11, 12},
+ }
+
+ var secureRNGBytes []byte
+ for _, n := range nonces {
+ secureRNGBytes = append(secureRNGBytes, n...)
+ }
+
for _, test := range tests {
test := test
@@ -419,7 +491,12 @@ func TestDADResolve(t *testing.T) {
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ var secureRNG bytes.Reader
+ secureRNG.Reset(secureRNGBytes)
+
s := stack.New(stack.Options{
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: stack.DADConfigurations{
@@ -553,7 +630,7 @@ func TestDADResolve(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr1),
- checker.NDPNSOptions(nil),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}),
))
if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 62f7c880e..ca15c0691 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
-// primaryAddress returns the primary address associated with this NIC.
-//
-// primaryAddress will return the first non-deprecated address if such an
-// address exists. If no non-deprecated address exists, the first deprecated
-// address will be returned.
-func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+// PrimaryAddress implements NetworkInterface.
+func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
ep, ok := n.networkEndpoints[proto]
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{}
}
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{}
}
- return addressableEndpoint.MainAddress()
+ return addressableEndpoint.MainAddress(), nil
}
// removeAddress removes an address from n.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 85f0f471a..ff3a385e1 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -525,6 +525,14 @@ type NetworkInterface interface {
// assigned to it.
Spoofing() bool
+ // PrimaryAddress returns the primary address associated with the interface.
+ //
+ // PrimaryAddress will return the first non-deprecated address if such an
+ // address exists. If no non-deprecated addresses exist, the first deprecated
+ // address will be returned. If no deprecated addresses exist, the zero value
+ // will be returned.
+ PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error)
+
// CheckLocalAddress returns true if the address exists on the interface.
CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 53370c354..931a97ddc 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -23,6 +23,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
+ "io"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -445,6 +446,9 @@ type Stack struct {
// used when a random number is required.
randomGenerator *mathrand.Rand
+ // secureRNG is a cryptographically secure random number generator.
+ secureRNG io.Reader
+
// sendBufferSize holds the min/default/max send buffer sizes for
// endpoints other than TCP.
sendBufferSize tcpip.SendBufferSizeOption
@@ -528,6 +532,9 @@ type Options struct {
// IPTables are the initial iptables rules. If nil, iptables will allow
// all traffic.
IPTables *IPTables
+
+ // SecureRNG is a cryptographically secure random number generator.
+ SecureRNG io.Reader
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -636,6 +643,10 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = rand.Reader
+ }
+
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -652,6 +663,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -1211,20 +1223,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
}
// GetMainNICAddress returns the first non-deprecated primary address and prefix
-// for the given NIC and protocol. If no non-deprecated primary address exists,
-// a deprecated primary address and prefix will be returned. Returns false if
-// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
-// address for the given protocol.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) {
+// for the given NIC and protocol. If no non-deprecated primary addresses exist,
+// a deprecated address will be returned. If no deprecated addresses exist, the
+// zero value will be returned.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.AddressWithPrefix{}, false
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownNICID{}
}
- return nic.primaryAddress(protocol), true
+ return nic.PrimaryAddress(protocol)
}
func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
@@ -2047,6 +2058,12 @@ func (s *Stack) Rand() *mathrand.Rand {
return s.randomGenerator
}
+// SecureRNG returns the stack's cryptographically secure random number
+// generator.
+func (s *Stack) SecureRNG() io.Reader {
+ return s.secureRNG
+}
+
func generateRandUint32() uint32 {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 880219007..7ddf7a083 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -62,10 +62,10 @@ const (
)
func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
- if addr, ok := s.GetMainNICAddress(nicID, proto); !ok {
- return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto)
+ if addr, err := s.GetMainNICAddress(nicID, proto); err != nil {
+ return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err)
} else if addr != want {
- return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want)
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want)
}
return nil
}
@@ -1854,6 +1854,8 @@ func TestNetworkOption(t *testing.T) {
}
func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
+ const nicID = 1
+
for _, addrLen := range []int{4, 16} {
t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
for canBe := 0; canBe < 3; canBe++ {
@@ -1864,8 +1866,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Insert <canBe> primary and <never> never-primary addresses.
// Each one will add a network endpoint to the NIC.
@@ -1888,34 +1890,34 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
PrefixLen: addrLen * 8,
},
}
- if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
- t.Fatal("AddProtocolAddressWithOptions failed:", err)
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
}
}
}
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
- gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if len(primaryAddrAdded) == 0 {
// No primary addresses present.
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr)
}
} else {
// At least one primary address was added, verify the returned
// address is in the list of primary addresses we added.
if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded)
}
}
})
@@ -1926,6 +1928,45 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
}
+func TestGetMainNICAddressErrors(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // Sanity check with a successful call.
+ if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil {
+ t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want)
+ }
+
+ const unknownNICID = nicID + 1
+ switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) {
+ case *tcpip.ErrUnknownNICID:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err)
+ }
+
+ // ARP is not an addressable network endpoint.
+ switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) {
+ case *tcpip.ErrNotSupported:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err)
+ }
+
+ const unknownProtocolNumber = 1234
+ switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) {
+ case *tcpip.ErrUnknownProtocol:
+ default:
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err)
+ }
+}
+
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
@@ -2507,11 +2548,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- // Check that we get no address after removal.
- if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
t.Fatal(err)
}
- if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+
+ // Disabling the NIC should remove the auto-generated address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
}
})
@@ -2617,6 +2662,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
// when an address's kind gets "promoted" to permanent from permanentExpired.
func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ const nicID = 1
+
pebs := []stack.PrimaryEndpointBehavior{
stack.NeverPrimaryEndpoint,
stack.CanBePrimaryEndpoint,
@@ -2630,8 +2677,8 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ if err := s.CreateNIC(nicID, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Add a permanent address with initial
@@ -2639,20 +2686,21 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// NeverPrimaryEndpoint, the address should not
// be returned by a call to GetMainNICAddress;
// else, it should.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ const address1 = tcpip.Address("\x01")
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
}
- addr, ok := s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if pi == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
}
- } else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
+ } else if addr.Address != address1 {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
}
{
@@ -2670,13 +2718,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// new peb is respected when an address gets
// "promoted" to permanent from a
// permanentExpired kind.
- r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ const address2 = tcpip.Address("\x02")
+ r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false)
if err != nil {
- t.Fatalf("FindRoute failed: %v", err)
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err)
}
defer r.Release()
- if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ if err := s.RemoveAddress(nicID, address1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err)
}
//
@@ -2687,19 +2736,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
+ const address3 = tcpip.Address("\x03")
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatalf("AddAddressWithOptions failed: %v", err)
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
}
var primaryAddrs []tcpip.Address
- for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
}
var expectedList []tcpip.Address
@@ -2728,20 +2778,20 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// should be returned by a call to
// GetMainNICAddress; else, our original address
// should be returned.
- if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed: %v", err)
+ if err := s.RemoveAddress(nicID, address3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err)
}
- addr, ok = s.GetMainNICAddress(1, fakeNetNumber)
- if !ok {
- t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
+ addr, err = s.GetMainNICAddress(nicID, fakeNetNumber)
+ if err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
}
} else {
- if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
+ if addr.Address != address1 {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
}
}
})
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 10cbbe589..c1c6cbccd 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -33,8 +33,8 @@ import (
)
const (
- testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 58aabe547..3cc8c36f1 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -72,11 +72,13 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 80afc2825..6462e9d42 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -24,11 +24,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
})
}
}
+
+func TestExternalLoopbackTraffic(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ numPackets = 1
+ )
+
+ loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ }
+
+ loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ }
+
+ loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ }
+
+ loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ }
+
+ invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidSourceAddressesReceived
+ }
+
+ invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidDestinationAddressesReceived
+ }
+
+ tests := []struct {
+ name string
+ allowExternalLoopback bool
+ forwarding bool
+ rxICMP func(*channel.Endpoint)
+ invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter
+ shouldAccept bool
+ }{
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled",
+ allowExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled",
+ allowExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: test.allowExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: test.allowExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ }
+
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ }
+
+ if test.forwarding {
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: ipv4Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ tcpip.Route{
+ Destination: header.IPv6Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ stats := s.Stats().IP
+ invalidAddressStat := test.invalidAddressStat(stats)
+ deliveredPacketsStat := stats.PacketsDelivered
+ if got := invalidAddressStat.Value(); got != 0 {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got)
+ }
+ if got := deliveredPacketsStat.Value(); got != 0 {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got)
+ }
+ test.rxICMP(e)
+ var expectedInvalidPackets uint64
+ if !test.shouldAccept {
+ expectedInvalidPackets = numPackets
+ }
+ if got := invalidAddressStat.Value(); got != expectedInvalidPackets {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets)
+ }
+ if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 29266a4fc..77f4a88ec 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -45,82 +45,61 @@ const (
func TestPingMulticastBroadcast(t *testing.T) {
const nicID = 1
- rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ttl,
- SrcAddr: utils.RemoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: utils.RemoteIPv6Addr,
- Dst: dst,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ttl,
- SrcAddr: utils.RemoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- dstAddr tcpip.Address
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectedSrc tcpip.Address
}{
{
- name: "IPv4 unicast",
- dstAddr: utils.Ipv4Addr.Address,
+ name: "IPv4 unicast",
+ protoNum: header.IPv4ProtocolNumber,
+ dstAddr: utils.Ipv4Addr.Address,
+ srcAddr: utils.RemoteIPv4Addr,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 directed broadcast",
- dstAddr: utils.Ipv4SubnetBcast,
+ name: "IPv4 directed broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4SubnetBcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 broadcast",
- dstAddr: header.IPv4Broadcast,
+ name: "IPv4 broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4Broadcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 all-systems multicast",
- dstAddr: header.IPv4AllSystems,
+ name: "IPv4 all-systems multicast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4AllSystems,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv6 unicast",
- dstAddr: utils.Ipv6Addr.Address,
+ name: "IPv6 unicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ expectedSrc: utils.Ipv6Addr.Address,
},
{
- name: "IPv6 all-nodes multicast",
- dstAddr: header.IPv6AllNodesMulticastAddress,
+ name: "IPv6 all-nodes multicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: header.IPv6AllNodesMulticastAddress,
+ expectedSrc: utils.Ipv6Addr.Address,
},
}
@@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- var rxICMP func(*channel.Endpoint, tcpip.Address)
- var expectedSrc tcpip.Address
- var expectedDst tcpip.Address
- var protoNum tcpip.NetworkProtocolNumber
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- rxICMP = rxIPv4ICMP
- expectedSrc = utils.Ipv4Addr.Address
- expectedDst = utils.RemoteIPv4Addr
- protoNum = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- rxICMP = rxIPv6ICMP
- expectedSrc = utils.Ipv6Addr.Address
- expectedDst = utils.RemoteIPv6Addr
- protoNum = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- rxICMP(e, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
}
- if pkt.Route.LocalAddress != expectedSrc {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc)
+ if pkt.Route.LocalAddress != test.expectedSrc {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
}
- if pkt.Route.RemoteAddress != expectedDst {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if pkt.Route.RemoteAddress != test.srcAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if src != expectedSrc {
- t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
+ src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if src != test.expectedSrc {
+ t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
}
- if dst != expectedDst {
- t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if dst != test.srcAddr {
+ t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 4455f6dd7..ed499179f 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -16,6 +16,7 @@ package route_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- HandleLocal: true,
- })
- e := test.linkEndpoint()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
+ for _, allowExternalLoopback := range []bool{true, false} {
+ t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: allowExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: allowExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
- }
- }
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
- {
- err := ep.Connect(connAddr)
- if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff)
- }
- }
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
- if test.expectedConnectErr != nil {
- return
- }
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
- payload := test.icmpBuf(t)
- var r bytes.Reader
- r.Reset(payload)
- var wOpts tcpip.WriteOptions
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
- } else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
- }
+ if test.expectedConnectErr != nil {
+ return
+ }
- // Wait for the endpoint to become readable.
- <-ch
+ var r bytes.Reader
+ payload := test.icmpBuf(t)
+ r.Reset(payload)
+ var wOpts tcpip.WriteOptions
+ if n, err := ep.Write(&r, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
+ // Wait for the endpoint to become readable.
+ <-ch
- test.checkLinkEndpoint(t, e)
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if rr.RemoteAddr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
})
}
}
diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD
index 433004148..a9699a367 100644
--- a/pkg/tcpip/tests/utils/BUILD
+++ b/pkg/tcpip/tests/utils/BUILD
@@ -8,12 +8,15 @@ go_library(
visibility = ["//pkg/tcpip/tests:__subpackages__"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
],
)
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index f414a2234..d1c9f3a94 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -20,13 +20,16 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
// Common NIC IDs used by tests.
@@ -45,6 +48,10 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
+const (
+ ttl = 255
+)
+
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
},
})
}
+
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
+// the provided endpoint.
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(header.ICMPv4UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(header.ICMPv6UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: src,
+ Dst: dst,
+ }))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3b574837c..0a2f3291c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash"
"io"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
-func (e *endpoint) deliverAccepted(n *endpoint) {
+func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
@@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
select {
case e.acceptedChan <- n:
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
e.acceptMu.Unlock()
e.waiterQueue.Notify(waiter.EventIn)
return
@@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.synRcvdCount--
+ atomic.AddInt32(&e.synRcvdCount, -1)
return err
}
@@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
+ atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep)
+ e.deliverAccepted(h.ep, false /*withSynCookie*/)
}() // S/R-SAFE: synRcvdCount is the barrier.
return nil
@@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) incSynRcvdCount() bool {
e.acceptMu.Lock()
- canInc := e.synRcvdCount < cap(e.acceptedChan)
+ canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
e.acceptMu.Unlock()
if canInc {
- e.synRcvdCount++
+ atomic.AddInt32(&e.synRcvdCount, 1)
}
return canInc
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
e.acceptMu.Unlock()
return full
}
@@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n)
+ go e.deliverAccepted(n, true /*withSynCookie*/)
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 129f36d11..43d344350 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -532,8 +532,8 @@ type endpoint struct {
segmentQueue segmentQueue `state:"wait"`
// synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state.
- synRcvdCount int
+ // in SYN-RCVD state; this is only accessed atomically.
+ synRcvdCount int32
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index a5c82b8fa..bc6793fc6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -260,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fd499a47b..6c86ae1ae 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -165,7 +165,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
@@ -178,7 +178,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
@@ -239,7 +239,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
@@ -269,7 +269,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -318,7 +318,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
// Send a SYN request for a closed port. This should elicit an RST
// but NOT an ICMPv4 DstUnreachable packet.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -362,7 +362,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -459,7 +459,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
@@ -483,7 +483,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
@@ -506,14 +506,14 @@ func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -537,18 +537,19 @@ func TestConnectResetAfterClose(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
// Close the endpoint, make sure we get a FIN segment, then acknowledge
// to complete closure of sender, but don't send our own FIN.
ep.Close()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -556,7 +557,7 @@ func TestConnectResetAfterClose(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -570,7 +571,7 @@ func TestConnectResetAfterClose(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -612,7 +613,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -625,12 +626,12 @@ func TestCurrentConnectedIncrement(t *testing.T) {
}
ep.Close()
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -638,7 +639,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -655,7 +656,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -666,7 +667,7 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -690,7 +691,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -699,11 +700,12 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -713,7 +715,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -734,7 +736,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -753,7 +755,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -764,7 +766,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 792,
+ SeqNum: iss.Add(2),
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -804,7 +806,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -813,11 +815,12 @@ func TestSimpleReceive(t *testing.T) {
ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -840,7 +843,7 @@ func TestSimpleReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1366,12 +1369,13 @@ func TestTOSV4(t *testing.T) {
// Check that data is received.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
@@ -1414,12 +1418,13 @@ func TestTrafficClassV6(t *testing.T) {
// Check that data is received.
b := c.GetV6Packet()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv6(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
@@ -1472,7 +1477,7 @@ func TestConnectBindToDevice(t *testing.T) {
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
@@ -1537,7 +1542,7 @@ func TestSynSent(t *testing.T) {
if test.reset {
// Send a packet with a proper ACK and a RST flag to cause the socket
// to error and close out.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
rcvWnd := seqnum.Size(30000)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
@@ -1582,7 +1587,7 @@ func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1593,11 +1598,12 @@ func TestOutOfOrderReceive(t *testing.T) {
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: iss.Add(3),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1607,7 +1613,7 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1621,7 +1627,7 @@ func TestOutOfOrderReceive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1639,7 +1645,7 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1650,19 +1656,20 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
rcvBufSz := math.MaxUint16
- c.CreateConnected(789, 30000, rcvBufSz)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
ept := endpointTester{c.EP}
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < 100; i++ {
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 796,
+ SeqNum: iss.Add(6),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1671,19 +1678,19 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- // Send packet with seqnum 793. It must be discarded because the
+ // Send packet with seqnum as initial + 3. It must be discarded because the
// out-of-order buffer was filled by the previous packets.
c.SendPacket(data[3:], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: iss.Add(3),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1692,27 +1699,27 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
- // Now send the expected packet, seqnum 790.
+ // Now send the expected packet with initial sequence number.
c.SendPacket(data[:3], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- // Check that only packet 790 is acknowledged.
+ // Check that only packet with initial sequence number is acknowledged.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(793),
+ checker.TCPAckNum(uint32(iss)+3),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1722,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1732,11 +1739,12 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1753,7 +1761,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1780,7 +1788,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
+ SeqNum: iss.Add(seqnum.Size(len(data))),
AckNum: c.IRS.Add(seqnum.Size(2)),
RcvWnd: 30000,
})
@@ -1790,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1800,11 +1808,12 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1821,7 +1830,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1866,7 +1875,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
+ SeqNum: iss.Add(seqnum.Size(len(data))),
AckNum: c.IRS.Add(seqnum.Size(2)),
RcvWnd: 30000,
})
@@ -1876,7 +1885,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
ept := endpointTester{c.EP}
ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
@@ -1897,7 +1906,7 @@ func TestFullWindowReceive(t *testing.T) {
defer c.Cleanup()
const rcvBufSz = 10
- c.CreateConnected(789, 30000, rcvBufSz)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1913,11 +1922,12 @@ func TestFullWindowReceive(t *testing.T) {
for i := range data {
data[i] = byte(i % 255)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1934,7 +1944,7 @@ func TestFullWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(0),
),
@@ -1956,7 +1966,7 @@ func TestFullWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(10),
),
@@ -1996,8 +2006,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
}
payload := generateRandomPayload(t, payloadSize)
payloadLen := seqnum.Size(len(payload))
- iss := seqnum.Value(789)
- seqNum := iss.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Send payload to the endpoint and return the advertised receive window
// from the endpoint.
@@ -2005,12 +2014,12 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
Flags: header.TCPFlagAck,
RcvWnd: 30000,
})
- seqNum = seqNum.Add(payloadLen)
+ iss = iss.Add(payloadLen)
pkt := c.GetPacket()
return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
@@ -2054,9 +2063,8 @@ func TestNoWindowShrinking(t *testing.T) {
// the right edge of the window does not shrink.
// NOTE: Netstack doubles the value specified here.
rcvBufSize := 65536
- iss := seqnum.Value(789)
// Enable window scaling with a scale of zero from our end.
- c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2069,13 +2077,13 @@ func TestNoWindowShrinking(t *testing.T) {
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
- seqNum := iss.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
payload := []byte{1}
c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2092,20 +2100,20 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data: %v, want: %v", got, want)
}
- seqNum = seqNum.Add(1)
// Verify that the ACK does not shrink the window.
pkt := c.GetPacket()
+ iss = iss.Add(1)
checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
// Stash the initial window.
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
@@ -2117,11 +2125,11 @@ func TestNoWindowShrinking(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+ iss = iss.Add(seqnum.Size(rcvBufSize / 2))
// Verify that the ACK does not shrink the window.
pkt = c.GetPacket()
@@ -2129,12 +2137,12 @@ func TestNoWindowShrinking(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd))
if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
@@ -2145,17 +2153,17 @@ func TestNoWindowShrinking(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqNum,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+ iss = iss.Add(seqnum.Size(rcvBufSize / 2))
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(0),
),
@@ -2173,7 +2181,7 @@ func TestNoWindowShrinking(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
@@ -2184,7 +2192,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2195,12 +2203,13 @@ func TestSimpleSend(t *testing.T) {
// Check that data is received.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2214,7 +2223,7 @@ func TestSimpleSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
RcvWnd: 30000,
})
@@ -2224,7 +2233,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2235,12 +2244,13 @@ func TestZeroWindowSend(t *testing.T) {
// Check if we got a zero-window probe.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2250,7 +2260,7 @@ func TestZeroWindowSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2262,7 +2272,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2276,7 +2286,7 @@ func TestZeroWindowSend(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
RcvWnd: 30000,
})
@@ -2289,7 +2299,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2303,12 +2313,13 @@ func TestScaledWindowConnect(t *testing.T) {
// Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2322,7 +2333,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- c.CreateConnected(789, 30000, 65535*3)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3)
data := []byte{1, 2, 3}
var r bytes.Reader
@@ -2334,12 +2345,13 @@ func TestNonScaledWindowConnect(t *testing.T) {
// Check that data is received, and that advertised window is 0xffff,
// that is, that it's not scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2407,12 +2419,13 @@ func TestScaledWindowAccept(t *testing.T) {
// Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2480,12 +2493,13 @@ func TestNonScaledWindowAccept(t *testing.T) {
// Check that data is received, and that advertised window is 0xffff,
// that is, that it's not scaled.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
@@ -2502,7 +2516,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the buffer size such that a window scale of 5 will be used.
const bufSz = 65535 * 10
const ws = uint32(5)
- c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -2510,13 +2524,14 @@ func TestZeroScaledWindowReceive(t *testing.T) {
remain := 0
sent := 0
data := make([]byte, 50000)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Keep writing till the window drops below len(data).
for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2527,7 +2542,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2545,7 +2560,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -2556,7 +2571,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2594,7 +2609,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -2632,7 +2647,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send tcp.InitialCwnd number of segments to fill up
// InitialWindow but don't ACK. That should prevent
@@ -2657,6 +2672,7 @@ func TestSegmentMerging(t *testing.T) {
}
// Check that we get tcp.InitialCwnd packets.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < tcp.InitialCwnd; i++ {
b := c.GetPacket()
checker.IPv4(t, b,
@@ -2664,7 +2680,7 @@ func TestSegmentMerging(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2675,7 +2691,7 @@ func TestSegmentMerging(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
RcvWnd: 30000,
})
@@ -2687,7 +2703,7 @@ func TestSegmentMerging(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+11),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2701,7 +2717,7 @@ func TestSegmentMerging(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
@@ -2713,7 +2729,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
c.EP.SocketOptions().SetDelayOption(true)
@@ -2728,6 +2744,7 @@ func TestDelay(t *testing.T) {
}
seq := c.IRS.Add(1)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for _, want := range [][]byte{allData[:1], allData[1:]} {
// Check that data is received.
b := c.GetPacket()
@@ -2736,7 +2753,7 @@ func TestDelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2751,7 +2768,7 @@ func TestDelay(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2762,7 +2779,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
c.EP.SocketOptions().SetDelayOption(true)
@@ -2776,7 +2793,7 @@ func TestUndelay(t *testing.T) {
}
seq := c.IRS.Add(1)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Check that data is received.
first := c.GetPacket()
checker.IPv4(t, first,
@@ -2784,7 +2801,7 @@ func TestUndelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2807,7 +2824,7 @@ func TestUndelay(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2823,7 +2840,7 @@ func TestUndelay(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2845,7 +2862,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -2861,7 +2878,7 @@ func TestMSSNotDelayed(t *testing.T) {
}
seq := c.IRS.Add(1)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i, data := range allData {
// Check that data is received.
packet := c.GetPacket()
@@ -2870,7 +2887,7 @@ func TestMSSNotDelayed(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2887,7 +2904,7 @@ func TestMSSNotDelayed(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seq,
RcvWnd: 30000,
})
@@ -2912,6 +2929,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
// Check that data is received in chunks.
bytesReceived := 0
numPackets := 0
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for bytesReceived != dataLen {
b := c.GetPacket()
numPackets++
@@ -2921,7 +2939,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2945,7 +2963,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
RcvWnd: 30000,
TCPOpts: options,
@@ -2961,7 +2979,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -3001,7 +3019,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -3210,7 +3228,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
)
// Send SYN-ACK.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: tcpHdr.DestinationPort(),
DstPort: tcpHdr.SourcePort(),
@@ -3272,14 +3290,15 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
RcvWnd: 30000,
})
@@ -3328,14 +3347,15 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send RST segment.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
RcvWnd: 30000,
})
@@ -3363,7 +3383,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
@@ -3426,7 +3446,7 @@ func TestMaxRTO(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
var r bytes.Reader
r.Reset(make([]byte, 1))
@@ -3469,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
// have DF=0. This needs to be done because the IPv4 ID uniqueness
@@ -3518,19 +3538,20 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3540,7 +3561,7 @@ func TestFinImmediately(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -3551,7 +3572,7 @@ func TestFinImmediately(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3561,19 +3582,20 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3584,7 +3606,7 @@ func TestFinRetransmit(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3594,7 +3616,7 @@ func TestFinRetransmit(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -3605,7 +3627,7 @@ func TestFinRetransmit(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3615,7 +3637,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := make([]byte, 10)
@@ -3626,12 +3648,13 @@ func TestFinWithNoPendingData(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3641,7 +3664,7 @@ func TestFinWithNoPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3656,7 +3679,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3667,7 +3690,7 @@ func TestFinWithNoPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3678,7 +3701,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3688,7 +3711,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -3702,13 +3725,14 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := tcp.InitialCwnd; i > 0; i-- {
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3727,7 +3751,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3737,7 +3761,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3747,7 +3771,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3758,7 +3782,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3768,7 +3792,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3778,7 +3802,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := make([]byte, 10)
@@ -3789,12 +3813,13 @@ func TestFinWithPendingData(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3804,7 +3829,7 @@ func TestFinWithPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3820,7 +3845,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3836,7 +3861,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3847,7 +3872,7 @@ func TestFinWithPendingData(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3857,7 +3882,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3867,7 +3892,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -3879,12 +3904,13 @@ func TestFinWithPartialAck(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3894,7 +3920,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -3905,7 +3931,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3921,7 +3947,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3937,7 +3963,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(791),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3948,7 +3974,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: seqnum.Value(next - 1),
RcvWnd: 30000,
})
@@ -3961,7 +3987,7 @@ func TestFinWithPartialAck(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 791,
+ SeqNum: iss.Add(1),
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -4002,17 +4028,18 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
// Open up the window with a scaled value.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 1,
})
@@ -4031,7 +4058,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4041,7 +4068,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagRst,
- SeqNum: 790,
+ SeqNum: iss,
})
}
@@ -4054,15 +4081,16 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4083,14 +4111,15 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
vv := c.BuildSegment(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: seqnum.Value(iss),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4110,14 +4139,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -4144,23 +4174,24 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
data := []byte{1, 2, 3}
for i := 0; i < 200; i++ {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + i*len(data)),
+ SeqNum: iss.Add(seqnum.Size(i * len(data))),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
}
// Receive ACKs for all segments.
- last := seqnum.Value(790 + 200*len(data))
+ last := iss.Add(seqnum.Size(200 * len(data)))
for {
b := c.GetPacket()
checker.IPv4(t, b,
@@ -4198,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -4212,12 +4243,13 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("Shutdown failed: %s", err)
}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -4232,7 +4264,7 @@ func TestReadAfterClosedState(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
@@ -4242,7 +4274,7 @@ func TestReadAfterClosedState(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(791+len(data))),
+ checker.TCPAckNum(uint32(iss)+uint32(len(data))+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4815,7 +4847,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -4833,6 +4865,7 @@ func TestPathMTUDiscovery(t *testing.T) {
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
var ret []byte
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i, size := range sizes {
p := c.GetPacket()
if i == which {
@@ -4843,7 +4876,7 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(seqNum),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4893,14 +4926,15 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -5027,7 +5061,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if connected {
- c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
}
if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
@@ -5067,7 +5101,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
@@ -5087,13 +5121,14 @@ func TestKeepalive(t *testing.T) {
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for i := 0; i < 10; i++ {
b := c.GetPacket()
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -5103,7 +5138,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: c.IRS,
RcvWnd: 30000,
})
@@ -5128,7 +5163,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -5140,7 +5175,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -5153,7 +5188,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -5166,7 +5201,7 @@ func TestKeepalive(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(next-1)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -5184,7 +5219,7 @@ func TestKeepalive(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -5215,7 +5250,7 @@ func TestKeepalive(t *testing.T) {
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
t.Helper()
// Send a SYN request.
- irs = seqnum.Value(789)
+ irs = seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: srcPort,
DstPort: context.StackPort,
@@ -5260,7 +5295,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
t.Helper()
// Send a SYN request.
- irs = seqnum.Value(789)
+ irs = seqnum.Value(context.TestInitialSequenceNumber)
c.SendV6Packet(nil, &context.Headers{
SrcPort: srcPort,
DstPort: context.StackPort,
@@ -5340,7 +5375,7 @@ func TestListenBacklogFull(t *testing.T) {
SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
RcvWnd: 30000,
})
c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
@@ -5491,7 +5526,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
t.Fatalf("Listen failed: %s", err)
}
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5591,7 +5626,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
t.Fatalf("Listen failed: %s", err)
}
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendV6PacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5645,7 +5680,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Send two SYN's the first one should get a SYN-ACK, the
// second one should not get any response and is dropped as
// the synRcvd count will be equal to backlog.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -5758,7 +5793,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
time.Sleep(50 * time.Millisecond)
// Send a SYN request.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
// pick a different src port for new SYN.
SrcPort: context.TestPort + 1,
@@ -5824,7 +5859,7 @@ func TestSYNRetransmit(t *testing.T) {
// Send the same SYN packet multiple times. We should still get a valid SYN-ACK
// reply.
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
for i := 0; i < 5; i++ {
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5867,7 +5902,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
}
// Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
- irs := seqnum.Value(789)
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6051,7 +6086,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
SrcPort: srcPort + 2,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
RcvWnd: 30000,
})
@@ -6501,7 +6536,7 @@ func TestTCPLingerTimeout(t *testing.T) {
c := context.New(t, 1500 /* mtu */)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
testCases := []struct {
name string
@@ -6552,7 +6587,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6671,7 +6706,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6778,7 +6813,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6852,7 +6887,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Send a SYN request w/ sequence number lower than
// the highest sequence number sent. We just reuse
// the same number.
- iss = seqnum.Value(789)
+ iss = seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6872,7 +6907,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
- iss = seqnum.Value(792)
+ iss = iss.Add(3)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -6942,7 +6977,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -7091,7 +7126,7 @@ func TestTCPCloseWithData(t *testing.T) {
}
// Send a SYN request.
- iss := seqnum.Value(789)
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -7242,7 +7277,7 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
@@ -7268,12 +7303,13 @@ func TestTCPUserTimeout(t *testing.T) {
}
next := uint32(c.IRS) + 1
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
- checker.TCPAckNum(790),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7299,7 +7335,7 @@ func TestTCPUserTimeout(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(next),
RcvWnd: 30000,
})
@@ -7328,7 +7364,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
@@ -7362,11 +7398,12 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(790)),
+ checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7383,7 +7420,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: iss,
AckNum: seqnum.Value(c.IRS + 1),
RcvWnd: 30000,
})
@@ -7413,20 +7450,20 @@ func TestIncreaseWindowOnRead(t *testing.T) {
defer c.Cleanup()
const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -7438,7 +7475,7 @@ func TestIncreaseWindowOnRead(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7469,7 +7506,7 @@ func TestIncreaseWindowOnRead(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7483,20 +7520,20 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
defer c.Cleanup()
const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
-
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
+ SeqNum: iss.Add(seqnum.Size(sent)),
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -7507,7 +7544,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7523,7 +7560,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPAckNum(uint32(iss)+uint32(sent)),
checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -7664,16 +7701,16 @@ func TestResetDuringClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- iss := seqnum.Value(789)
- c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */)
// Send some data to make sure there is some unread
// data to trigger a reset on c.Close.
irs := c.IRS
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
+ SeqNum: iss,
AckNum: irs.Add(1),
RcvWnd: 30000,
})
@@ -7683,7 +7720,7 @@ func TestResetDuringClose(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.TCPSeqNum(uint32(irs.Add(1))),
- checker.TCPAckNum(uint32(iss.Add(5)))))
+ checker.TCPAckNum(uint32(iss)+4)))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7698,7 +7735,7 @@ func TestResetDuringClose(t *testing.T) {
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: iss.Add(5),
+ SeqNum: iss.Add(4),
AckNum: c.IRS.Add(5),
RcvWnd: 30000,
Flags: header.TCPFlagRst,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 5d81dbb94..c8126b51b 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -45,8 +45,8 @@ import (
// represents the remote endpoint.
const (
v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
testV4MappedAddr = v4MappedAddrPrefix + testAddr
multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
index 047091e75..dbe17fa5e 100644
--- a/pkg/test/dockerutil/network.go
+++ b/pkg/test/dockerutil/network.go
@@ -102,11 +102,8 @@ func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
}
-// Cleanup cleans up the docker network and all the containers attached to it.
+// Cleanup cleans up the docker network.
func (n *Network) Cleanup(ctx context.Context) error {
- for _, c := range n.containers {
- c.CleanUp(ctx)
- }
n.containers = nil
return n.client.NetworkRemove(ctx, n.id)