diff options
Diffstat (limited to 'pkg')
93 files changed, 3079 insertions, 1089 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 59b0e138a..114b516e2 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -44,6 +44,8 @@ go_library( "poll.go", "prctl.go", "ptrace.go", + "ptrace_amd64.go", + "ptrace_arm64.go", "rseq.go", "rusage.go", "sched.go", diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go new file mode 100644 index 000000000..ed3881e27 --- /dev/null +++ b/pkg/abi/linux/ptrace_amd64.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// PtraceRegs is the set of CPU registers exposed by ptrace. Source: +// syscall.PtraceRegs. +// +// +marshal +// +stateify savable +type PtraceRegs struct { + R15 uint64 + R14 uint64 + R13 uint64 + R12 uint64 + Rbp uint64 + Rbx uint64 + R11 uint64 + R10 uint64 + R9 uint64 + R8 uint64 + Rax uint64 + Rcx uint64 + Rdx uint64 + Rsi uint64 + Rdi uint64 + Orig_rax uint64 + Rip uint64 + Cs uint64 + Eflags uint64 + Rsp uint64 + Ss uint64 + Fs_base uint64 + Gs_base uint64 + Ds uint64 + Es uint64 + Fs uint64 + Gs uint64 +} diff --git a/pkg/sentry/arch/arch_state_aarch64.go b/pkg/abi/linux/ptrace_arm64.go index 0136a85ad..6147738b3 100644 --- a/pkg/sentry/arch/arch_state_aarch64.go +++ b/pkg/abi/linux/ptrace_arm64.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,25 +14,16 @@ // +build arm64 -package arch +package linux -import ( - "syscall" -) - -type syscallPtraceRegs struct { +// PtraceRegs is the set of CPU registers exposed by ptrace. Source: +// syscall.PtraceRegs. +// +// +marshal +// +stateify savable +type PtraceRegs struct { Regs [31]uint64 Sp uint64 Pc uint64 Pstate uint64 } - -// saveRegs is invoked by stateify. -func (s *State) saveRegs() syscallPtraceRegs { - return syscallPtraceRegs(s.Regs) -} - -// loadRegs is invoked by stateify. -func (s *State) loadRegs(r syscallPtraceRegs) { - s.Regs = syscall.PtraceRegs(r) -} diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index e27f21e5e..901e0f320 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -11,7 +11,6 @@ go_library( "arch_amd64.go", "arch_amd64.s", "arch_arm64.go", - "arch_state_aarch64.go", "arch_state_x86.go", "arch_x86.go", "arch_x86_impl.go", @@ -26,11 +25,11 @@ go_library( "syscalls_amd64.go", "syscalls_arm64.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ ":registers_go_proto", "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/cpuid", "//pkg/log", @@ -38,6 +37,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index c29e1b841..529980267 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -17,18 +17,20 @@ package arch import ( + "encoding/binary" "fmt" "io" - "syscall" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) +// Registers represents the CPU registers for this architecture. +type Registers = linux.PtraceRegs + const ( // SyscallWidth is the width of insturctions. SyscallWidth = 4 @@ -90,7 +92,7 @@ func NewFloatingPointData() *FloatingPointData { // file ensures it's only built on aarch64). type State struct { // The system registers. - Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"` + Regs Registers // Our floating point state. aarch64FPState `state:"wait"` @@ -226,25 +228,27 @@ func (s *State) RegisterMap() (map[string]uintptr, error) { // PtraceGetRegs implements Context.PtraceGetRegs. func (s *State) PtraceGetRegs(dst io.Writer) (int, error) { - return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs())) + regs := s.ptraceGetRegs() + n, err := regs.WriteTo(dst) + return int(n), err } -func (s *State) ptraceGetRegs() syscall.PtraceRegs { +func (s *State) ptraceGetRegs() Registers { return s.Regs } -var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{})) +var registersSize = (*Registers)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { - var regs syscall.PtraceRegs - buf := make([]byte, ptraceRegsSize) + var regs Registers + buf := make([]byte, registersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } - binary.Unmarshal(buf, usermem.ByteOrder, ®s) + regs.UnmarshalUnsafe(buf) s.Regs = regs - return ptraceRegsSize, nil + return registersSize, nil } // PtraceGetFPRegs implements Context.PtraceGetFPRegs. diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 85d6acc0f..3b3a0a272 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -22,7 +22,6 @@ import ( "math/rand" "syscall" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" @@ -301,8 +300,10 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and // u_debugreg, returning 0 or silently no-oping for other fields // respectively. - if addr < uintptr(ptraceRegsSize) { - buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) + if addr < uintptr(registersSize) { + regs := c.ptraceGetRegs() + buf := make([]byte, regs.SizeBytes()) + regs.MarshalUnsafe(buf) return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil } // Note: x86 debug registers are missing. @@ -314,8 +315,10 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { if addr&7 != 0 || addr >= userStructSize { return syscall.EIO } - if addr < uintptr(ptraceRegsSize) { - buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) + if addr < uintptr(registersSize) { + regs := c.ptraceGetRegs() + buf := make([]byte, regs.SizeBytes()) + regs.MarshalUnsafe(buf) usermem.ByteOrder.PutUint64(buf[addr:], uint64(data)) _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index db99c5acb..ada7ac7b8 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package arch import ( diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index aa31169e0..19ce99d25 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -18,7 +18,6 @@ package arch import ( "fmt" - "syscall" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" @@ -90,44 +89,3 @@ func (s *State) afterLoadFPState() { // Copy to the new, aligned location. copy(s.x86FPState, old) } - -// +stateify savable -type syscallPtraceRegs struct { - R15 uint64 - R14 uint64 - R13 uint64 - R12 uint64 - Rbp uint64 - Rbx uint64 - R11 uint64 - R10 uint64 - R9 uint64 - R8 uint64 - Rax uint64 - Rcx uint64 - Rdx uint64 - Rsi uint64 - Rdi uint64 - Orig_rax uint64 - Rip uint64 - Cs uint64 - Eflags uint64 - Rsp uint64 - Ss uint64 - Fs_base uint64 - Gs_base uint64 - Ds uint64 - Es uint64 - Fs uint64 - Gs uint64 -} - -// saveRegs is invoked by stateify. -func (s *State) saveRegs() syscallPtraceRegs { - return syscallPtraceRegs(s.Regs) -} - -// loadRegs is invoked by stateify. -func (s *State) loadRegs(r syscallPtraceRegs) { - s.Regs = syscall.PtraceRegs(r) -} diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 7fc4c0473..dc458b37f 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -21,7 +21,7 @@ import ( "io" "syscall" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" @@ -30,6 +30,9 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// Registers represents the CPU registers for this architecture. +type Registers = linux.PtraceRegs + // System-related constants for x86. const ( // SyscallWidth is the width of syscall, sysenter, and int 80 insturctions. @@ -267,10 +270,12 @@ func (s *State) RegisterMap() (map[string]uintptr, error) { // PtraceGetRegs implements Context.PtraceGetRegs. func (s *State) PtraceGetRegs(dst io.Writer) (int, error) { - return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs())) + regs := s.ptraceGetRegs() + n, err := regs.WriteTo(dst) + return int(n), err } -func (s *State) ptraceGetRegs() syscall.PtraceRegs { +func (s *State) ptraceGetRegs() Registers { regs := s.Regs // These may not be initialized. if regs.Cs == 0 || regs.Ss == 0 || regs.Eflags == 0 { @@ -306,16 +311,16 @@ func (s *State) ptraceGetRegs() syscall.PtraceRegs { return regs } -var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{})) +var registersSize = (*Registers)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { - var regs syscall.PtraceRegs - buf := make([]byte, ptraceRegsSize) + var regs Registers + buf := make([]byte, registersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } - binary.Unmarshal(buf, usermem.ByteOrder, ®s) + regs.UnmarshalUnsafe(buf) // Truncate segment registers to 16 bits. regs.Cs = uint64(uint16(regs.Cs)) regs.Ds = uint64(uint16(regs.Ds)) @@ -369,7 +374,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) { } regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable) s.Regs = regs - return ptraceRegsSize, nil + return registersSize, nil } // isUserSegmentSelector returns true if the given segment selector specifies a @@ -538,7 +543,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < ptraceRegsSize { + if maxlen < registersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -558,7 +563,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < ptraceRegsSize { + if maxlen < registersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go index 3edf40764..0c73fcbfb 100644 --- a/pkg/sentry/arch/arch_x86_impl.go +++ b/pkg/sentry/arch/arch_x86_impl.go @@ -17,8 +17,6 @@ package arch import ( - "syscall" - "gvisor.dev/gvisor/pkg/cpuid" ) @@ -28,7 +26,7 @@ import ( // +stateify savable type State struct { // The system registers. - Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"` + Regs Registers // Our floating point state. x86FPState `state:"wait"` diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 8b03d0187..c9fb55d00 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -22,6 +22,7 @@ import ( // SignalAct represents the action that should be taken when a signal is // delivered, and is equivalent to struct sigaction. // +// +marshal // +stateify savable type SignalAct struct { Handler uint64 @@ -43,6 +44,7 @@ func (s *SignalAct) DeserializeTo(other *SignalAct) { // SignalStack represents information about a user stack, and is equivalent to // stack_t. // +// +marshal // +stateify savable type SignalStack struct { Addr uint64 @@ -64,6 +66,7 @@ func (s *SignalStack) DeserializeTo(other *SignalStack) { // SignalInfo represents information about a signal being delivered, and is // equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). // +// +marshal // +stateify savable type SignalInfo struct { Signo int32 // Signal number diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go index f9ca2e74e..32173aa20 100644 --- a/pkg/sentry/arch/signal_act.go +++ b/pkg/sentry/arch/signal_act.go @@ -14,6 +14,8 @@ package arch +import "gvisor.dev/gvisor/tools/go_marshal/marshal" + // Special values for SignalAct.Handler. const ( // SignalActDefault is SIG_DFL and specifies that the default behavior for @@ -71,6 +73,8 @@ func (s SignalAct) HasRestorer() bool { // NativeSignalAct is a type that is equivalent to struct sigaction in the // guest architecture. type NativeSignalAct interface { + marshal.Marshallable + // SerializeFrom copies the data in the host SignalAct s into this object. SerializeFrom(s *SignalAct) diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index e58f055c7..0fa738a1d 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -18,6 +18,7 @@ package arch import ( "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) const ( @@ -55,6 +56,8 @@ func (s *SignalStack) Contains(sp usermem.Addr) bool { // NativeSignalStack is a type that is equivalent to stack_t in the guest // architecture. type NativeSignalStack interface { + marshal.Marshallable + // SerializeFrom copies the data in the host SignalStack s into this // object. SerializeFrom(s *SignalStack) diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 10ba2f5f0..40f2c1cad 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -47,6 +47,8 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport. return &endpoint{inode, i.fileState.file.file, path} } +// LINT.IfChange + // endpoint is a Gofer-backed transport.BoundEndpoint. // // An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint() @@ -146,3 +148,5 @@ func (e *endpoint) Release() { func (e *endpoint) Passcred() bool { return false } + +// LINT.ThenChange(../../fsimpl/gofer/socket.go) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 06fc2d80a..b6e94583e 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -37,6 +37,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + // maxSendBufferSize is the maximum host send buffer size allowed for endpoint. // // N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). @@ -388,3 +390,5 @@ func (c *ConnectedEndpoint) Release() { // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. func (c *ConnectedEndpoint) CloseUnread() {} + +// LINT.ThenChange(../../fsimpl/host/socket.go) diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index af6955675..5c18dbd5e 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LINT.IfChange + // maxIovs is the maximum number of iovecs to pass to the host. var maxIovs = linux.UIO_MAXIOV @@ -111,3 +113,5 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec return total, iovecs, nil, err } + +// LINT.ThenChange(../../fsimpl/host/socket_iovec.go) diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index f3bbed7ea..5d4f312cf 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -19,6 +19,8 @@ import ( "unsafe" ) +// LINT.IfChange + // fdReadVec receives from fd to bufs. // // If the total length of bufs is > maxlen, fdReadVec will do a partial read @@ -99,3 +101,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int6 return int64(n), length, err } + +// LINT.ThenChange(../../fsimpl/host/socket_unsafe.go) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 2c22a04af..77b644275 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -485,11 +485,14 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) { - _, _, err := fs.walk(rp, false) +func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { + _, inode, err := fs.walk(rp, false) if err != nil { return nil, err } + if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } // TODO(b/134676337): Support sockets. return nil, syserror.ECONNREFUSED diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index b9c4beee4..5ce82b793 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "p9file.go", "pagemath.go", "regular_file.go", + "socket.go", "special_file.go", "symlink.go", "time.go", @@ -52,18 +53,25 @@ go_library( "//pkg/p9", "//pkg/safemem", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/host", "//pkg/sentry/hostfd", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", + "//pkg/sentry/socket/control", + "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/syserr", "//pkg/syserror", "//pkg/unet", "//pkg/usermem", + "//pkg/waiter", ], ) diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 55f9ed911..b98218753 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -15,6 +15,7 @@ package gofer import ( + "fmt" "sync" "sync/atomic" @@ -22,6 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -59,28 +62,52 @@ func (d *dentry) cacheNegativeLookupLocked(name string) { d.children[name] = nil } -// createSyntheticDirectory creates a synthetic directory with the given name +type createSyntheticOpts struct { + name string + mode linux.FileMode + kuid auth.KUID + kgid auth.KGID + + // The endpoint for a synthetic socket. endpoint should be nil if the file + // being created is not a socket. + endpoint transport.BoundEndpoint + + // pipe should be nil if the file being created is not a pipe. + pipe *pipe.VFSPipe +} + +// createSyntheticChildLocked creates a synthetic file with the given name // in d. // // Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain // a child with the given name. -func (d *dentry) createSyntheticDirectoryLocked(name string, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) { +func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { d2 := &dentry{ refs: 1, // held by d fs: d.fs, - mode: uint32(mode) | linux.S_IFDIR, - uid: uint32(kuid), - gid: uint32(kgid), + mode: uint32(opts.mode), + uid: uint32(opts.kuid), + gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary handle: handle{ fd: -1, }, nlink: uint32(2), } + switch opts.mode.FileType() { + case linux.S_IFDIR: + // Nothing else needs to be done. + case linux.S_IFSOCK: + d2.endpoint = opts.endpoint + case linux.S_IFIFO: + d2.pipe = opts.pipe + default: + panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType())) + } d2.pf.dentry = d2 d2.vfsd.Init(d2) - d.cacheNewChildLocked(d2, name) + d.cacheNewChildLocked(d2, opts.name) d.syntheticChildren++ } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 98ccb42fd..4a8411371 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -22,9 +22,11 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -652,7 +654,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) - parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID) + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: linux.S_IFDIR | opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + }) } if fs.opts.interop != InteropModeShared { parent.incLinks() @@ -663,7 +670,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // Can't create non-synthetic files in synthetic directories. return syserror.EPERM } - parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID) + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: linux.S_IFDIR | opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + }) parent.incLinks() return nil }) @@ -674,6 +686,28 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { creds := rp.Credentials() _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) + if err == syserror.EPERM { + switch opts.Mode.FileType() { + case linux.S_IFSOCK: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + endpoint: opts.Endpoint, + }) + return nil + case linux.S_IFIFO: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + }) + return nil + } + } return err }, nil) } @@ -756,20 +790,21 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, err } mnt := rp.Mount() - filetype := d.fileType() - switch { - case filetype == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD: - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil { - return nil, err - } - fd := ®ularFileFD{} - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { - return nil, err + switch d.fileType() { + case linux.S_IFREG: + if !d.fs.opts.regularFilesUseSpecialFileFD { + if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil { + return nil, err + } + fd := ®ularFileFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ + AllowDirectIO: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil } - return &fd.vfsfd, nil - case filetype == linux.S_IFDIR: + case linux.S_IFDIR: // Can't open directories with O_CREAT. if opts.Flags&linux.O_CREAT != 0 { return nil, syserror.EISDIR @@ -791,26 +826,40 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, err } return &fd.vfsfd, nil - case filetype == linux.S_IFLNK: + case linux.S_IFLNK: // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP - default: - if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL - } - h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0) - if err != nil { - return nil, err - } - fd := &specialFileFD{ - handle: h, + case linux.S_IFSOCK: + if d.isSynthetic() { + return nil, syserror.ENXIO } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - h.close(ctx) - return nil, err + case linux.S_IFIFO: + if d.isSynthetic() { + return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags) } - return &fd.vfsfd, nil } + return d.openSpecialFileLocked(ctx, mnt, opts) +} + +func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { + ats := vfs.AccessTypesForOpenFlags(opts) + // Treat as a special file. This is done for non-synthetic pipes as well as + // regular files when d.fs.opts.regularFilesUseSpecialFileFD is true. + if opts.Flags&linux.O_DIRECT != 0 { + return nil, syserror.EINVAL + } + h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0) + if err != nil { + return nil, err + } + fd := &specialFileFD{ + handle: h, + } + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + h.close(ctx) + return nil, err + } + return &fd.vfsfd, nil } // Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. @@ -1196,15 +1245,28 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) { +func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(&ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err } - // TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt. + if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } + if d.isSocket() { + if !d.isSynthetic() { + d.IncRef() + return &endpoint{ + dentry: d, + file: d.file.file, + path: opts.Addr, + }, nil + } + return d.endpoint, nil + } return nil, syserror.ECONNREFUSED } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 8b4e91d17..3235816c2 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -46,9 +46,11 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" @@ -585,6 +587,14 @@ type dentry struct { // and target are protected by dataMu. haveTarget bool target string + + // If this dentry represents a synthetic socket file, endpoint is the + // transport endpoint bound to this file. + endpoint transport.BoundEndpoint + + // If this dentry represents a synthetic named pipe, pipe is the pipe + // endpoint bound to this file. + pipe *pipe.VFSPipe } // dentryAttrMask returns a p9.AttrMask enabling all attributes used by the @@ -791,10 +801,21 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin setLocalAtime = stat.Mask&linux.STATX_ATIME != 0 setLocalMtime = stat.Mask&linux.STATX_MTIME != 0 stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME - if !setLocalMtime && (stat.Mask&linux.STATX_SIZE != 0) { - // Truncate updates mtime. - setLocalMtime = true - stat.Mtime.Nsec = linux.UTIME_NOW + + // Prepare for truncate. + if stat.Mask&linux.STATX_SIZE != 0 { + switch d.mode & linux.S_IFMT { + case linux.S_IFREG: + if !setLocalMtime { + // Truncate updates mtime. + setLocalMtime = true + stat.Mtime.Nsec = linux.UTIME_NOW + } + case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } } } d.metadataMu.Lock() @@ -1049,6 +1070,7 @@ func (d *dentry) destroyLocked() { d.handle.close(ctx) } d.handleMu.Unlock() + if !d.file.isNil() { d.file.close(ctx) d.file = p9file{} @@ -1134,7 +1156,7 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name return d.file.removeXattr(ctx, name) } -// Preconditions: !d.file.isNil(). d.isRegularFile() || d.isDirectory(). +// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDirectory(). func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { // O_TRUNC unconditionally requires us to obtain a new handle (opened with // O_TRUNC). diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go new file mode 100644 index 000000000..73835df91 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -0,0 +1,141 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/waiter" +) + +func (d *dentry) isSocket() bool { + return d.fileType() == linux.S_IFSOCK +} + +// endpoint is a Gofer-backed transport.BoundEndpoint. +// +// An endpoint's lifetime is the time between when filesystem.BoundEndpointAt() +// is called and either BoundEndpoint.BidirectionalConnect or +// BoundEndpoint.UnidirectionalConnect is called. +type endpoint struct { + // dentry is the filesystem dentry which produced this endpoint. + dentry *dentry + + // file is the p9 file that contains a single unopened fid. + file p9.File + + // path is the sentry path where this endpoint is bound. + path string +} + +func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { + switch t { + case linux.SOCK_STREAM: + return p9.StreamSocket, true + case linux.SOCK_SEQPACKET: + return p9.SeqpacketSocket, true + case linux.SOCK_DGRAM: + return p9.DgramSocket, true + } + return 0, false +} + +// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. +func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { + cf, ok := sockTypeToP9(ce.Type()) + if !ok { + return syserr.ErrConnectionRefused + } + + // No lock ordering required as only the ConnectingEndpoint has a mutex. + ce.Lock() + + // Check connecting state. + if ce.Connected() { + ce.Unlock() + return syserr.ErrAlreadyConnected + } + if ce.Listening() { + ce.Unlock() + return syserr.ErrInvalidEndpointState + } + + c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue()) + if err != nil { + ce.Unlock() + return err + } + + returnConnect(c, c) + ce.Unlock() + c.Init() + + return nil +} + +// UnidirectionalConnect implements +// transport.BoundEndpoint.UnidirectionalConnect. +func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) { + c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{}) + if err != nil { + return nil, err + } + c.Init() + + // We don't need the receiver. + c.CloseRecv() + c.Release() + + return c, nil +} + +func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { + hostFile, err := e.file.Connect(flags) + if err != nil { + return nil, syserr.ErrConnectionRefused + } + // Dup the fd so that the new endpoint can manage its lifetime. + hostFD, err := syscall.Dup(hostFile.FD()) + if err != nil { + log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err) + return nil, syserr.FromError(err) + } + // After duplicating, we no longer need hostFile. + hostFile.Close() + + c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) + if serr != nil { + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + return nil, serr + } + return c, nil +} + +// Release implements transport.BoundEndpoint.Release. +func (e *endpoint) Release() { + e.dentry.DecRef() +} + +// Passcred implements transport.BoundEndpoint.Passcred. +func (e *endpoint) Passcred() bool { + return false +} diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 2dcb03a73..e1c56d89b 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,6 +8,9 @@ go_library( "control.go", "host.go", "ioctl_unsafe.go", + "socket.go", + "socket_iovec.go", + "socket_unsafe.go", "tty.go", "util.go", "util_unsafe.go", @@ -16,6 +19,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fdnotifier", "//pkg/log", "//pkg/refs", "//pkg/sentry/arch", @@ -25,12 +29,18 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/control", + "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", + "//pkg/sentry/uniqueid", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/unet", "//pkg/usermem", + "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 1e53b5c1b..05538ea42 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -17,7 +17,6 @@ package host import ( - "errors" "fmt" "math" "syscall" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" + unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -156,7 +156,7 @@ type inode struct { // Note that these flags may become out of date, since they can be modified // on the host, e.g. with fcntl. -func fileFlagsFromHostFD(fd int) (int, error) { +func fileFlagsFromHostFD(fd int) (uint32, error) { flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0) if err != nil { log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err) @@ -164,7 +164,7 @@ func fileFlagsFromHostFD(fd int) (int, error) { } // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags. flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND - return flags, nil + return uint32(flags), nil } // CheckPermissions implements kernfs.Inode. @@ -361,6 +361,10 @@ func (i *inode) Destroy() { // Open implements kernfs.Inode. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/. + if i.Mode().FileType() == linux.S_IFSOCK { + return nil, syserror.ENXIO + } return i.open(ctx, vfsd, rp.Mount()) } @@ -370,42 +374,45 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount) (*vfs.F return nil, err } fileType := s.Mode & linux.FileTypeMask + flags, err := fileFlagsFromHostFD(i.hostFD) + if err != nil { + return nil, err + } + if fileType == syscall.S_IFSOCK { if i.isTTY { - return nil, errors.New("cannot use host socket as TTY") + log.Warningf("cannot use host socket fd %d as TTY", i.hostFD) + return nil, syserror.ENOTTY } - // TODO(gvisor.dev/issue/1672): support importing sockets. - return nil, errors.New("importing host sockets not supported") + + ep, err := newEndpoint(ctx, i.hostFD) + if err != nil { + return nil, err + } + // Currently, we only allow Unix sockets to be imported. + return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d) } // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that // we don't allow importing arbitrary file types without proper support. - var ( - vfsfd *vfs.FileDescription - fdImpl vfs.FileDescriptionImpl - ) if i.isTTY { fd := &ttyFD{ fileDescription: fileDescription{inode: i}, termios: linux.DefaultSlaveTermios, } - vfsfd = &fd.vfsfd - fdImpl = fd - } else { - // For simplicity, set offset to 0. Technically, we should - // only set to 0 on files that are not seekable (sockets, pipes, etc.), - // and use the offset from the host fd otherwise. - fd := &fileDescription{inode: i} - vfsfd = &fd.vfsfd - fdImpl = fd - } - - flags, err := fileFlagsFromHostFD(i.hostFD) - if err != nil { - return nil, err + vfsfd := &fd.vfsfd + if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return vfsfd, nil } - if err := vfsfd.Init(fdImpl, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + // For simplicity, set offset to 0. Technically, we should + // only set to 0 on files that are not seekable (sockets, pipes, etc.), + // and use the offset from the host fd otherwise. + fd := &fileDescription{inode: i} + vfsfd := &fd.vfsfd + if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go new file mode 100644 index 000000000..39dd624a0 --- /dev/null +++ b/pkg/sentry/fsimpl/host/socket.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/unet" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Create a new host-backed endpoint from the given fd. +func newEndpoint(ctx context.Context, hostFD int) (transport.Endpoint, error) { + // Set up an external transport.Endpoint using the host fd. + addr := fmt.Sprintf("hostfd:[%d]", hostFD) + var q waiter.Queue + e, err := NewConnectedEndpoint(ctx, hostFD, &q, addr, true /* saveable */) + if err != nil { + return nil, err.ToError() + } + e.Init() + ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) + return ep, nil +} + +// maxSendBufferSize is the maximum host send buffer size allowed for endpoint. +// +// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). +const maxSendBufferSize = 8 << 20 + +// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and +// transport.Receiver. It is backed by a host fd that was imported at sentry +// startup. This fd is shared with a hostfs inode, which retains ownership of +// it. +// +// ConnectedEndpoint is saveable, since we expect that the host will provide +// the same fd upon restore. +// +// As of this writing, we only allow Unix sockets to be imported. +// +// +stateify savable +type ConnectedEndpoint struct { + // ref keeps track of references to a ConnectedEndpoint. + ref refs.AtomicRefCount + + // mu protects fd below. + mu sync.RWMutex `state:"nosave"` + + // fd is the host fd backing this endpoint. + fd int + + // addr is the address at which this endpoint is bound. + addr string + + queue *waiter.Queue + + // sndbuf is the size of the send buffer. + // + // N.B. When this is smaller than the host size, we present it via + // GetSockOpt and message splitting/rejection in SendMsg, but do not + // prevent lots of small messages from filling the real send buffer + // size on the host. + sndbuf int64 `state:"nosave"` + + // stype is the type of Unix socket. + stype linux.SockType +} + +// init performs initialization required for creating new ConnectedEndpoints and +// for restoring them. +func (c *ConnectedEndpoint) init() *syserr.Error { + family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN) + if err != nil { + return syserr.FromError(err) + } + + if family != syscall.AF_UNIX { + // We only allow Unix sockets. + return syserr.ErrInvalidEndpointState + } + + stype, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_TYPE) + if err != nil { + return syserr.FromError(err) + } + + if err := syscall.SetNonblock(c.fd, true); err != nil { + return syserr.FromError(err) + } + + sndbuf, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + return syserr.FromError(err) + } + if sndbuf > maxSendBufferSize { + log.Warningf("Socket send buffer too large: %d", sndbuf) + return syserr.ErrInvalidEndpointState + } + + c.stype = linux.SockType(stype) + c.sndbuf = int64(sndbuf) + + return nil +} + +// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd +// imported at sentry startup, +// +// The caller is responsible for calling Init(). Additionaly, Release needs to +// be called twice because ConnectedEndpoint is both a transport.Receiver and +// transport.ConnectedEndpoint. +func NewConnectedEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) { + e := ConnectedEndpoint{ + fd: hostFD, + addr: addr, + queue: queue, + } + + if err := e.init(); err != nil { + return nil, err + } + + // AtomicRefCounters start off with a single reference. We need two. + e.ref.IncRef() + e.ref.EnableLeakCheck("host.ConnectedEndpoint") + return &e, nil +} + +// Init will do the initialization required without holding other locks. +func (c *ConnectedEndpoint) Init() { + if err := fdnotifier.AddFD(int32(c.fd), c.queue); err != nil { + panic(err) + } +} + +// Send implements transport.ConnectedEndpoint.Send. +func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { + c.mu.RLock() + defer c.mu.RUnlock() + + if !controlMessages.Empty() { + return 0, false, syserr.ErrInvalidEndpointState + } + + // Since stream sockets don't preserve message boundaries, we can write + // only as much of the message as fits in the send buffer. + truncate := c.stype == linux.SOCK_STREAM + + n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate) + if n < totalLen && err == nil { + // The host only returns a short write if it would otherwise + // block (and only for stream sockets). + err = syserror.EAGAIN + } + if n > 0 && err != syserror.EAGAIN { + // The caller may need to block to send more data, but + // otherwise there isn't anything that can be done about an + // error with a partial write. + err = nil + } + + // There is no need for the callee to call SendNotify because fdWriteVec + // uses the host's sendmsg(2) and the host kernel's queue. + return n, false, syserr.FromError(err) +} + +// SendNotify implements transport.ConnectedEndpoint.SendNotify. +func (c *ConnectedEndpoint) SendNotify() {} + +// CloseSend implements transport.ConnectedEndpoint.CloseSend. +func (c *ConnectedEndpoint) CloseSend() { + c.mu.Lock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.fd, syscall.SHUT_WR); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err)) + } +} + +// CloseNotify implements transport.ConnectedEndpoint.CloseNotify. +func (c *ConnectedEndpoint) CloseNotify() {} + +// Writable implements transport.ConnectedEndpoint.Writable. +func (c *ConnectedEndpoint) Writable() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventOut)&waiter.EventOut != 0 +} + +// Passcred implements transport.ConnectedEndpoint.Passcred. +func (c *ConnectedEndpoint) Passcred() bool { + // We don't support credential passing for host sockets. + return false +} + +// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil +} + +// EventUpdate implements transport.ConnectedEndpoint.EventUpdate. +func (c *ConnectedEndpoint) EventUpdate() { + c.mu.RLock() + defer c.mu.RUnlock() + if c.fd != -1 { + fdnotifier.UpdateFD(int32(c.fd)) + } +} + +// Recv implements transport.Receiver.Recv. +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { + c.mu.RLock() + defer c.mu.RUnlock() + + var cm unet.ControlMessage + if numRights > 0 { + cm.EnableFDs(int(numRights)) + } + + // N.B. Unix sockets don't have a receive buffer, the send buffer + // serves both purposes. + rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf) + if rl > 0 && err != nil { + // We got some data, so all we need to do on error is return + // the data that we got. Short reads are fine, no need to + // block. + err = nil + } + if err != nil { + return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err) + } + + // There is no need for the callee to call RecvNotify because fdReadVec uses + // the host's recvmsg(2) and the host kernel's queue. + + // Trim the control data if we received less than the full amount. + if cl < uint64(len(cm)) { + cm = cm[:cl] + } + + // Avoid extra allocations in the case where there isn't any control data. + if len(cm) == 0 { + return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil + } + + fds, err := cm.ExtractFDs() + if err != nil { + return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err) + } + + if len(fds) == 0 { + return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil + } + return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil +} + +// RecvNotify implements transport.Receiver.RecvNotify. +func (c *ConnectedEndpoint) RecvNotify() {} + +// CloseRecv implements transport.Receiver.CloseRecv. +func (c *ConnectedEndpoint) CloseRecv() { + c.mu.Lock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.fd, syscall.SHUT_RD); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err)) + } +} + +// Readable implements transport.Receiver.Readable. +func (c *ConnectedEndpoint) Readable() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventIn)&waiter.EventIn != 0 +} + +// SendQueuedSize implements transport.Receiver.SendQueuedSize. +func (c *ConnectedEndpoint) SendQueuedSize() int64 { + // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). + return -1 +} + +// RecvQueuedSize implements transport.Receiver.RecvQueuedSize. +func (c *ConnectedEndpoint) RecvQueuedSize() int64 { + // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). + return -1 +} + +// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. +func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { + return int64(c.sndbuf) +} + +// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. +func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { + // N.B. Unix sockets don't use the receive buffer. We'll claim it is + // the same size as the send buffer. + return int64(c.sndbuf) +} + +func (c *ConnectedEndpoint) destroyLocked() { + fdnotifier.RemoveFD(int32(c.fd)) + c.fd = -1 +} + +// Release implements transport.ConnectedEndpoint.Release and +// transport.Receiver.Release. +func (c *ConnectedEndpoint) Release() { + c.ref.DecRefWithDestructor(func() { + c.mu.Lock() + c.destroyLocked() + c.mu.Unlock() + }) +} + +// CloseUnread implements transport.ConnectedEndpoint.CloseUnread. +func (c *ConnectedEndpoint) CloseUnread() {} + +// SCMConnectedEndpoint represents an endpoint backed by a host fd that was +// passed through a gofer Unix socket. It is almost the same as +// ConnectedEndpoint, with the following differences: +// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee +// the same descriptor number across S/R. +// - SCMConnectedEndpoint holds ownership of its fd and is responsible for +// closing it. +type SCMConnectedEndpoint struct { + ConnectedEndpoint +} + +// Release implements transport.ConnectedEndpoint.Release and +// transport.Receiver.Release. +func (e *SCMConnectedEndpoint) Release() { + e.ref.DecRefWithDestructor(func() { + e.mu.Lock() + if err := syscall.Close(e.fd); err != nil { + log.Warningf("Failed to close host fd %d: %v", err) + } + e.destroyLocked() + e.mu.Unlock() + }) +} + +// NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that +// was passed through a Unix socket. +// +// The caller is responsible for calling Init(). Additionaly, Release needs to +// be called twice because ConnectedEndpoint is both a transport.Receiver and +// transport.ConnectedEndpoint. +func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) { + e := SCMConnectedEndpoint{ConnectedEndpoint{ + fd: hostFD, + addr: addr, + queue: queue, + }} + + if err := e.init(); err != nil { + return nil, err + } + + // AtomicRefCounters start off with a single reference. We need two. + e.ref.IncRef() + e.ref.EnableLeakCheck("host.SCMConnectedEndpoint") + return &e, nil +} diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go new file mode 100644 index 000000000..584c247d2 --- /dev/null +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserror" +) + +// maxIovs is the maximum number of iovecs to pass to the host. +var maxIovs = linux.UIO_MAXIOV + +// copyToMulti copies as many bytes from src to dst as possible. +func copyToMulti(dst [][]byte, src []byte) { + for _, d := range dst { + done := copy(d, src) + src = src[done:] + if len(src) == 0 { + break + } + } +} + +// copyFromMulti copies as many bytes from src to dst as possible. +func copyFromMulti(dst []byte, src [][]byte) { + for _, s := range src { + done := copy(dst, s) + dst = dst[done:] + if len(dst) == 0 { + break + } + } +} + +// buildIovec builds an iovec slice from the given []byte slice. +// +// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error. +// +// If length < the total length of bufs, err indicates why, even when returning +// a truncated iovec. +// +// If intermediate != nil, iovecs references intermediate rather than bufs and +// the caller must copy to/from bufs as necessary. +func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) { + var iovsRequired int + for _, b := range bufs { + length += int64(len(b)) + if len(b) > 0 { + iovsRequired++ + } + } + + stopLen := length + if length > maxlen { + if truncate { + stopLen = maxlen + err = syserror.EAGAIN + } else { + return 0, nil, nil, syserror.EMSGSIZE + } + } + + if iovsRequired > maxIovs { + // The kernel will reject our call if we pass this many iovs. + // Use a single intermediate buffer instead. + b := make([]byte, stopLen) + + return stopLen, []syscall.Iovec{{ + Base: &b[0], + Len: uint64(stopLen), + }}, b, err + } + + var total int64 + iovecs = make([]syscall.Iovec, 0, iovsRequired) + for i := range bufs { + l := len(bufs[i]) + if l == 0 { + continue + } + + stop := int64(l) + if total+stop > stopLen { + stop = stopLen - total + } + + iovecs = append(iovecs, syscall.Iovec{ + Base: &bufs[i][0], + Len: uint64(stop), + }) + + total += stop + if total >= stopLen { + break + } + } + + return total, iovecs, nil, err +} diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go new file mode 100644 index 000000000..35ded24bc --- /dev/null +++ b/pkg/sentry/fsimpl/host/socket_unsafe.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "syscall" + "unsafe" +) + +// fdReadVec receives from fd to bufs. +// +// If the total length of bufs is > maxlen, fdReadVec will do a partial read +// and err will indicate why the message was truncated. +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) { + flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) + if peek { + flags |= syscall.MSG_PEEK + } + + // Always truncate the receive buffer. All socket types will truncate + // received messages. + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, 0, 0, false, err + } + + var msg syscall.Msghdr + if len(control) != 0 { + msg.Control = &control[0] + msg.Controllen = uint64(len(control)) + } + + if len(iovecs) != 0 { + msg.Iov = &iovecs[0] + msg.Iovlen = uint64(len(iovecs)) + } + + rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) + if e != 0 { + // N.B. prioritize the syscall error over the buildIovec error. + return 0, 0, 0, false, e + } + n := int64(rawN) + + // Copy data back to bufs. + if intermediate != nil { + copyToMulti(bufs, intermediate) + } + + controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC + + if n > length { + return length, n, msg.Controllen, controlTrunc, err + } + + return n, n, msg.Controllen, controlTrunc, err +} + +// fdWriteVec sends from bufs to fd. +// +// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a +// partial write and err will indicate why the message was truncated. +func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) { + length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate) + if err != nil && len(iovecs) == 0 { + // No partial write to do, return error immediately. + return 0, length, err + } + + // Copy data to intermediate buf. + if intermediate != nil { + copyFromMulti(intermediate, bufs) + } + + var msg syscall.Msghdr + if len(iovecs) > 0 { + msg.Iov = &iovecs[0] + msg.Iovlen = uint64(len(iovecs)) + } + + n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL) + if e != 0 { + // N.B. prioritize the syscall error over the buildIovec error. + return 0, length, e + } + + return int64(n), length, err +} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 9e8d80414..4a12ae245 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -766,14 +766,17 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) { +func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs() if err != nil { return nil, err } + if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } return nil, syserror.ECONNREFUSED } diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index 5ce50625b..271134af8 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -73,26 +73,14 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr return nil, syserror.ENXIO } -// InitSocket initializes a socket FileDescription, with a corresponding -// Dentry in mnt. -// -// fd should be the FileDescription associated with socketImpl, i.e. its first -// field. mnt should be the global socket mount, Kernel.socketMount. -func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error { - fsimpl := mnt.Filesystem().Impl() - fs := fsimpl.(*kernfs.Filesystem) - +// NewDentry constructs and returns a sockfs dentry. +func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry { // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.Init(creds, fs.NextIno(), filemode) + i.Init(creds, ino, filemode) d := &kernfs.Dentry{} d.Init(i) - - opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true} - if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil { - return err - } - return nil + return d.VFSDentry() } diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 5b62f9ebb..36ffcb592 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -697,13 +697,16 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) { +func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { return nil, err } + if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } switch impl := d.inode.impl.(type) { case *socketFile: return impl.ep, nil diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 7d25e98f7..79766cafe 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -716,7 +716,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error { n := t.Arch().NewSignalAct() n.SerializeFrom(s) - _, err := t.CopyOut(addr, n) + _, err := n.CopyOut(t, addr) return err } @@ -725,7 +725,7 @@ func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error { func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) { n := t.Arch().NewSignalAct() var s arch.SignalAct - if _, err := t.CopyIn(addr, n); err != nil { + if _, err := n.CopyIn(t, addr); err != nil { return s, err } n.DeserializeTo(&s) @@ -737,7 +737,7 @@ func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) { func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error { n := t.Arch().NewSignalStack() n.SerializeFrom(s) - _, err := t.CopyOut(addr, n) + _, err := n.CopyOut(t, addr) return err } @@ -746,7 +746,7 @@ func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) { n := t.Arch().NewSignalStack() var s arch.SignalStack - if _, err := t.CopyIn(addr, n); err != nil { + if _, err := n.CopyIn(t, addr); err != nil { return s, err } n.DeserializeTo(&s) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index a5035bb7f..8485fb4b6 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -104,6 +104,9 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) { cfg.TaskContext.release() cfg.FSContext.DecRef() cfg.FDTable.DecRef() + if cfg.MountNamespaceVFS2 != nil { + cfg.MountNamespaceVFS2.DecRef() + } return nil, err } return t, nil diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 716198712..29d457a7e 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -17,8 +17,7 @@ package kvm import ( - "syscall" - + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) @@ -37,7 +36,7 @@ type userFpsimdState struct { } type userRegs struct { - Regs syscall.PtraceRegs + Regs arch.Registers sp_el1 uint64 elr_el1 uint64 spsr [KVM_NR_SPSR]uint64 diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index c42752d50..6c8f4fa28 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -117,10 +117,10 @@ func TestKernelFloatingPoint(t *testing.T) { }) } -func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) { +func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *arch.Registers, *pagetables.PageTables) bool) { // Initialize registers & page tables. var ( - regs syscall.PtraceRegs + regs arch.Registers pt *pagetables.PageTables ) testutil.SetTestTarget(®s, target) @@ -154,7 +154,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func } func TestApplicationSyscall(t *testing.T) { - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, @@ -168,7 +168,7 @@ func TestApplicationSyscall(t *testing.T) { } return false }) - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, @@ -184,7 +184,7 @@ func TestApplicationSyscall(t *testing.T) { } func TestApplicationFault(t *testing.T) { - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ @@ -199,7 +199,7 @@ func TestApplicationFault(t *testing.T) { } return false }) - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ @@ -216,7 +216,7 @@ func TestApplicationFault(t *testing.T) { } func TestRegistersSyscall(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { var si arch.SignalInfo @@ -239,7 +239,7 @@ func TestRegistersSyscall(t *testing.T) { } func TestRegistersFault(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { var si arch.SignalInfo @@ -263,7 +263,7 @@ func TestRegistersFault(t *testing.T) { } func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestSegments(regs) for { var si arch.SignalInfo @@ -287,7 +287,7 @@ func TestSegments(t *testing.T) { } func TestBounce(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { time.Sleep(time.Millisecond) c.BounceToKernel() @@ -302,7 +302,7 @@ func TestBounce(t *testing.T) { } return false }) - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { time.Sleep(time.Millisecond) c.BounceToKernel() @@ -321,7 +321,7 @@ func TestBounce(t *testing.T) { } func TestBounceStress(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { randomSleep := func() { // O(hundreds of microseconds) is appropriate to ensure // different overlaps and different schedules. @@ -357,7 +357,7 @@ func TestBounceStress(t *testing.T) { func TestInvalidate(t *testing.T) { var data uintptr // Used below. - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, &data) // Read legitimate value. for { var si arch.SignalInfo @@ -398,7 +398,7 @@ func IsFault(err error, si *arch.SignalInfo) bool { } func TestEmptyAddressSpace(t *testing.T) { - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, @@ -412,7 +412,7 @@ func TestEmptyAddressSpace(t *testing.T) { } return false }) - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, @@ -471,7 +471,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { i int // Iteration includes machine.Get() / machine.Put(). a int // Count for ErrContextInterrupt. ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, @@ -493,7 +493,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { func BenchmarkKernelSyscall(b *testing.B) { // Note that the target passed here is irrelevant, we never execute SwitchToUser. - applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { // iteration does not include machine.Get() / machine.Put(). for i := 0; i < b.N; i++ { testutil.Getpid() @@ -508,7 +508,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { i int a int ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { + applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { var si arch.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index f7605df8a..f7feb8683 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -13,4 +13,5 @@ go_library( "testutil_arm64.s", ], visibility = ["//pkg/sentry/platform/kvm:__pkg__"], + deps = ["//pkg/sentry/arch"], ) diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go index 4c108abbf..8048eedec 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go @@ -18,19 +18,20 @@ package testutil import ( "reflect" - "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) // TwiddleSegments reads segments into known registers. func TwiddleSegments() // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { +func SetTestTarget(regs *arch.Registers, fn func()) { regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) } // SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { +func SetTouchTarget(regs *arch.Registers, target *uintptr) { if target != nil { regs.Rax = uint64(reflect.ValueOf(target).Pointer()) } else { @@ -39,12 +40,12 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { } // RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *syscall.PtraceRegs) { +func RewindSyscall(regs *arch.Registers) { regs.Rip -= 2 } // SetTestRegs initializes registers to known values. -func SetTestRegs(regs *syscall.PtraceRegs) { +func SetTestRegs(regs *arch.Registers) { regs.R15 = 0x15 regs.R14 = 0x14 regs.R13 = 0x13 @@ -64,7 +65,7 @@ func SetTestRegs(regs *syscall.PtraceRegs) { } // CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { +func CheckTestRegs(regs *arch.Registers, full bool) (err error) { if need := ^uint64(0x15); regs.R15 != need { err = addRegisterMismatch(err, "R15", regs.R15, need) } @@ -121,13 +122,13 @@ var fsData uint64 = 0x55 var gsData uint64 = 0x85 // SetTestSegments initializes segments to known values. -func SetTestSegments(regs *syscall.PtraceRegs) { +func SetTestSegments(regs *arch.Registers) { regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer()) regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer()) } // CheckTestSegments checks that registers were twiddled per TwiddleSegments. -func CheckTestSegments(regs *syscall.PtraceRegs) (err error) { +func CheckTestSegments(regs *arch.Registers) (err error) { if regs.Rax != fsData { err = addRegisterMismatch(err, "Rax", regs.Rax, fsData) } diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go index 40b2e4acc..ca902c8c1 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -19,16 +19,17 @@ package testutil import ( "fmt" "reflect" - "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) // SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { +func SetTestTarget(regs *arch.Registers, fn func()) { regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) } // SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { +func SetTouchTarget(regs *arch.Registers, target *uintptr) { if target != nil { regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) } else { @@ -37,19 +38,19 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { } // RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *syscall.PtraceRegs) { +func RewindSyscall(regs *arch.Registers) { regs.Pc -= 4 } // SetTestRegs initializes registers to known values. -func SetTestRegs(regs *syscall.PtraceRegs) { +func SetTestRegs(regs *arch.Registers) { for i := 0; i <= 30; i++ { regs.Regs[i] = uint64(i) + 1 } } // CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { +func CheckTestRegs(regs *arch.Registers, full bool) (err error) { for i := 0; i <= 30; i++ { if need := ^uint64(i + 1); regs.Regs[i] != need { err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go index 24fc5dc62..3b9a870a5 100644 --- a/pkg/sentry/platform/ptrace/ptrace_amd64.go +++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go @@ -15,9 +15,8 @@ package ptrace import ( - "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" ) // fpRegSet returns the GETREGSET/SETREGSET register set type to be used. @@ -28,12 +27,12 @@ func fpRegSet(useXsave bool) uintptr { return linux.NT_PRFPREG } -func stackPointer(r *syscall.PtraceRegs) uintptr { +func stackPointer(r *arch.Registers) uintptr { return uintptr(r.Rsp) } // x86 use the fs_base register to store the TLS pointer which can be -// get/set in "func (t *thread) get/setRegs(regs *syscall.PtraceRegs)". +// get/set in "func (t *thread) get/setRegs(regs *arch.Registers)". // So both of the get/setTLS() operations are noop here. // getTLS gets the thread local storage register. diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go index 4db28c534..5c869926a 100644 --- a/pkg/sentry/platform/ptrace/ptrace_arm64.go +++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go @@ -15,9 +15,8 @@ package ptrace import ( - "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" ) // fpRegSet returns the GETREGSET/SETREGSET register set type to be used. @@ -25,6 +24,6 @@ func fpRegSet(_ bool) uintptr { return linux.NT_PRFPREG } -func stackPointer(r *syscall.PtraceRegs) uintptr { +func stackPointer(r *arch.Registers) uintptr { return uintptr(r.Sp) } diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 6c0ed7b3e..8b72d24e8 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -24,7 +24,7 @@ import ( ) // getRegs gets the general purpose register set. -func (t *thread) getRegs(regs *syscall.PtraceRegs) error { +func (t *thread) getRegs(regs *arch.Registers) error { iovec := syscall.Iovec{ Base: (*byte)(unsafe.Pointer(regs)), Len: uint64(unsafe.Sizeof(*regs)), @@ -43,7 +43,7 @@ func (t *thread) getRegs(regs *syscall.PtraceRegs) error { } // setRegs sets the general purpose register set. -func (t *thread) setRegs(regs *syscall.PtraceRegs) error { +func (t *thread) setRegs(regs *arch.Registers) error { iovec := syscall.Iovec{ Base: (*byte)(unsafe.Pointer(regs)), Len: uint64(unsafe.Sizeof(*regs)), diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 773ddb1ed..2389423b0 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -63,7 +63,7 @@ type thread struct { // initRegs are the initial registers for the first thread. // // These are used for the register set for system calls. - initRegs syscall.PtraceRegs + initRegs arch.Registers } // threadPool is a collection of threads. @@ -317,7 +317,7 @@ const ( ) func (t *thread) dumpAndPanic(message string) { - var regs syscall.PtraceRegs + var regs arch.Registers message += "\n" if err := t.getRegs(®s); err == nil { message += dumpRegs(®s) @@ -423,7 +423,7 @@ func (t *thread) init() { // This is _not_ for use by application system calls, rather it is for use when // a system call must be injected into the remote context (e.g. mmap, munmap). // Note that clones are handled separately. -func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { +func (t *thread) syscall(regs *arch.Registers) (uintptr, error) { // Set registers. if err := t.setRegs(regs); err != nil { panic(fmt.Sprintf("ptrace set regs failed: %v", err)) @@ -461,7 +461,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { // syscallIgnoreInterrupt ignores interrupts on the system call thread and // restarts the syscall if the kernel indicates that should happen. func (t *thread) syscallIgnoreInterrupt( - initRegs *syscall.PtraceRegs, + initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) (uintptr, error) { for { diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index cd74945e7..84b699f0d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -41,7 +41,7 @@ const ( // resetSysemuRegs sets up emulation registers. // // This should be called prior to calling sysemu. -func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) { +func (t *thread) resetSysemuRegs(regs *arch.Registers) { regs.Cs = t.initRegs.Cs regs.Ss = t.initRegs.Ss regs.Ds = t.initRegs.Ds @@ -53,7 +53,7 @@ func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) { // createSyscallRegs sets up syscall registers. // // This should be called to generate registers for a system call. -func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs { +func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers { // Copy initial registers. regs := *initRegs @@ -82,18 +82,18 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch } // isSingleStepping determines if the registers indicate single-stepping. -func isSingleStepping(regs *syscall.PtraceRegs) bool { +func isSingleStepping(regs *arch.Registers) bool { return (regs.Eflags & arch.X86TrapFlag) != 0 } // updateSyscallRegs updates registers after finishing sysemu. -func updateSyscallRegs(regs *syscall.PtraceRegs) { +func updateSyscallRegs(regs *arch.Registers) { // Ptrace puts -ENOSYS in rax on syscall-enter-stop. regs.Rax = regs.Orig_rax } // syscallReturnValue extracts a sensible return from registers. -func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { +func syscallReturnValue(regs *arch.Registers) (uintptr, error) { rval := int64(regs.Rax) if rval < 0 { return 0, syscall.Errno(-rval) @@ -101,7 +101,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { return uintptr(rval), nil } -func dumpRegs(regs *syscall.PtraceRegs) string { +func dumpRegs(regs *arch.Registers) string { var m strings.Builder fmt.Fprintf(&m, "Registers:\n") @@ -143,7 +143,7 @@ func (t *thread) adjustInitRegsRip() { } // Pass the expected PPID to the child via R15 when creating stub process. -func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { +func initChildProcessPPID(initregs *arch.Registers, ppid int32) { initregs.R15 = uint64(ppid) // Rbx has to be set to 1 when creating stub process. initregs.Rbx = 1 @@ -156,7 +156,7 @@ func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. -func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index 7f5c393f0..bd618fae8 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -41,13 +41,13 @@ const ( // resetSysemuRegs sets up emulation registers. // // This should be called prior to calling sysemu. -func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) { +func (t *thread) resetSysemuRegs(regs *arch.Registers) { } // createSyscallRegs sets up syscall registers. // // This should be called to generate registers for a system call. -func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs { +func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers { // Copy initial registers (Pc, Sp, etc.). regs := *initRegs @@ -78,7 +78,7 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch } // isSingleStepping determines if the registers indicate single-stepping. -func isSingleStepping(regs *syscall.PtraceRegs) bool { +func isSingleStepping(regs *arch.Registers) bool { // Refer to the ARM SDM D2.12.3: software step state machine // return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1). // @@ -89,13 +89,13 @@ func isSingleStepping(regs *syscall.PtraceRegs) bool { } // updateSyscallRegs updates registers after finishing sysemu. -func updateSyscallRegs(regs *syscall.PtraceRegs) { +func updateSyscallRegs(regs *arch.Registers) { // No special work is necessary. return } // syscallReturnValue extracts a sensible return from registers. -func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { +func syscallReturnValue(regs *arch.Registers) (uintptr, error) { rval := int64(regs.Regs[0]) if rval < 0 { return 0, syscall.Errno(-rval) @@ -103,7 +103,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) { return uintptr(rval), nil } -func dumpRegs(regs *syscall.PtraceRegs) string { +func dumpRegs(regs *arch.Registers) string { var m strings.Builder fmt.Fprintf(&m, "Registers:\n") @@ -125,7 +125,7 @@ func (t *thread) adjustInitRegsRip() { } // Pass the expected PPID to the child via X7 when creating stub process -func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { +func initChildProcessPPID(initregs *arch.Registers, ppid int32) { initregs.Regs[7] = uint64(ppid) // R9 has to be set to 1 when creating stub process. initregs.Regs[9] = 1 @@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. -func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index b69520030..679b287c3 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -79,6 +79,7 @@ go_library( deps = [ "//pkg/cpuid", "//pkg/safecopy", + "//pkg/sentry/arch", "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 86fd5ed58..e6daf24df 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -15,8 +15,7 @@ package ring0 import ( - "syscall" - + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) @@ -72,7 +71,7 @@ type CPU struct { // registers is a set of registers; these may be used on kernel system // calls and exceptions via the Registers function. - registers syscall.PtraceRegs + registers arch.Registers // hooks are kernel hooks. hooks Hooks @@ -83,14 +82,14 @@ type CPU struct { // This is explicitly safe to call during KernelException and KernelSyscall. // //go:nosplit -func (c *CPU) Registers() *syscall.PtraceRegs { +func (c *CPU) Registers() *arch.Registers { return &c.registers } // SwitchOpts are passed to the Switch function. type SwitchOpts struct { // Registers are the user register state. - Registers *syscall.PtraceRegs + Registers *arch.Registers // FloatingPointState is a byte pointer where floating point state is // saved and restored. diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go index a5ce67885..7fa43c2f5 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -17,7 +17,7 @@ package ring0 import ( - "syscall" + "gvisor.dev/gvisor/pkg/sentry/arch" ) // This is an assembly function. @@ -41,7 +41,7 @@ func swapgs() // The return code is the vector that interrupted execution. // // See stubs.go for a note regarding the frame size of this function. -func sysret(*CPU, *syscall.PtraceRegs) Vector +func sysret(*CPU, *arch.Registers) Vector // "iret is the cadillac of CPL switching." // @@ -50,7 +50,7 @@ func sysret(*CPU, *syscall.PtraceRegs) Vector // iret is nearly identical to sysret, except an iret is used to fully restore // all user state. This must be called in cases where all registers need to be // restored. -func iret(*CPU, *syscall.PtraceRegs) Vector +func iret(*CPU, *arch.Registers) Vector // exception is the generic exception entry. // diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 4cae10459..549f3d228 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -27,6 +27,7 @@ go_binary( visibility = ["//pkg/sentry/platform/ring0:__pkg__"], deps = [ "//pkg/cpuid", + "//pkg/sentry/arch", "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go index 85cc3fdad..b8ab120a0 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -20,7 +20,8 @@ import ( "fmt" "io" "reflect" - "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) // Emit prints architecture-specific offsets. @@ -64,7 +65,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) - p := &syscall.PtraceRegs{} + p := &arch.Registers{} fmt.Fprintf(w, "\n// Ptrace registers.\n") fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 057fb5c69..f3de962f0 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -20,7 +20,8 @@ import ( "fmt" "io" "reflect" - "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) // Emit prints architecture-specific offsets. @@ -87,7 +88,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) - p := &syscall.PtraceRegs{} + p := &arch.Registers{} fmt.Fprintf(w, "\n// Ptrace registers.\n") fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer()) fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer()) diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index de2cc4bdf..941a91097 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 7c64f30fa..5b29e9d7f 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -103,7 +103,7 @@ func (s *socketOpsCommon) DecRef() { } // Release implemements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release() { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. s.DecRef() @@ -323,7 +323,10 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Create the socket. // - // TODO(gvisor.dev/issue/2324): Correctly set file permissions. + // Note that the file permissions here are not set correctly (see + // gvisor.dev/issue/2324). There is no convenient way to get permissions + // on the socket referred to by s, so we will leave this discrepancy + // unresolved until VFS2 replaces this code. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse @@ -373,7 +376,7 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, Path: p, FollowFinalSymlink: true, } - ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop) + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) root.DecRef() if relPath { start.DecRef() diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 3e54d49c4..23db93f33 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/control" @@ -42,30 +43,44 @@ type SocketVFS2 struct { socketOpsCommon } -// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket. -func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { - sock := NewFDImpl(ep, stype) - vfsfd := &sock.vfsfd - if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil { +// NewSockfsFile creates a new socket file in the global sockfs mount and +// returns a corresponding file description. +func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d) + if err != nil { return nil, syserr.FromError(err) } - return vfsfd, nil + return fd, nil } -// NewFDImpl creates and returns a new SocketVFS2. -func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 { +// NewFileDescription creates and returns a socket file description +// corresponding to the given mount and dentry. +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) { // You can create AF_UNIX, SOCK_RAW sockets. They're the same as // SOCK_DGRAM and don't require CAP_NET_RAW. if stype == linux.SOCK_RAW { stype = linux.SOCK_DGRAM } - return &SocketVFS2{ + sock := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ ep: ep, stype: stype, }, } + vfsfd := &sock.vfsfd + if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return vfsfd, nil } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by @@ -112,8 +127,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block } } - // We expect this to be a FileDescription here. - ns, err := NewVFS2File(t, ep, s.stype) + ns, err := NewSockfsFile(t, ep, s.stype) if err != nil { return 0, nil, 0, err } @@ -183,11 +197,13 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { Start: start, Path: path, } - err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ - // TODO(gvisor.dev/issue/2324): The file permissions should be taken - // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind), - // but VFS1 just always uses 0400. Resolve this inconsistency. - Mode: linux.S_IFSOCK | 0400, + stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE}) + if err != nil { + return syserr.FromError(err) + } + err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ + // File permissions correspond to net/unix/af_unix.c:unix_bind. + Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()), Endpoint: bep, }) if err == syserror.EEXIST { @@ -259,13 +275,6 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs }) } -// Release implements vfs.FileDescriptionImpl. -func (s *SocketVFS2) Release() { - // Release only decrements a reference on s because s may be referenced in - // the abstract socket namespace. - s.DecRef() -} - // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) @@ -307,7 +316,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) return nil, syserr.ErrInvalidArgument } - f, err := NewVFS2File(t, ep, stype) + f, err := NewSockfsFile(t, ep, stype) if err != nil { ep.Close() return nil, err @@ -331,13 +340,13 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (* // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) - s1, err := NewVFS2File(t, ep1, stype) + s1, err := NewSockfsFile(t, ep1, stype) if err != nil { ep1.Close() ep2.Close() return nil, nil, err } - s2, err := NewVFS2File(t, ep2, stype) + s2, err := NewSockfsFile(t, ep2, stype) if err != nil { s1.DecRef() ep2.Close() diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 7e1747a0c..582d37e03 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -355,7 +355,7 @@ func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func RtSigpending(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() pending := t.PendingSignals() - _, err := t.CopyOut(addr, pending) + _, err := pending.CopyOut(t, addr) return 0, nil, err } @@ -392,7 +392,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if siginfo != 0 { si.FixSignalCodeForUser() - if _, err := t.CopyOut(siginfo, si); err != nil { + if _, err := si.CopyOut(t, siginfo); err != nil { return 0, nil, err } } @@ -411,7 +411,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // same way), and that the code is in the allowed set. This same logic // appears below in RtSigtgqueueinfo and should be kept in sync. var info arch.SignalInfo - if _, err := t.CopyIn(infoAddr, &info); err != nil { + if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } info.Signo = int32(sig) @@ -455,7 +455,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // Copy in the info. See RtSigqueueinfo above. var info arch.SignalInfo - if _, err := t.CopyIn(infoAddr, &info); err != nil { + if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } info.Signo = int32(sig) @@ -485,7 +485,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // Copy in the signal mask. var mask linux.SignalSet - if _, err := t.CopyIn(sigset, &mask); err != nil { + if _, err := mask.CopyIn(t, sigset); err != nil { return 0, nil, err } mask &^= kernel.UnblockableSignals diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index a64d86122..981bd8caa 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -237,10 +237,13 @@ func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error) { +func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error) { if !rp.Final() { return nil, syserror.ENOTDIR } + if err := GenericCheckPermissions(rp.Credentials(), MayWrite, anonFileMode, anonFileUID, anonFileGID); err != nil { + return nil, err + } return nil, syserror.ECONNREFUSED } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 20e5bb072..1edd584c9 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -494,8 +494,14 @@ type FilesystemImpl interface { // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. // - // - If a non-socket file exists at rp, then BoundEndpointAt returns ECONNREFUSED. - BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error) + // Errors: + // + // - If the file does not have write permissions, then BoundEndpointAt + // returns EACCES. + // + // - If a non-socket file exists at rp, then BoundEndpointAt returns + // ECONNREFUSED. + BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error) // PrependPath prepends a path from vd to vd.Mount().Root() to b. // diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 022bac127..53d364c5c 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -151,6 +151,24 @@ type SetStatOptions struct { Stat linux.Statx } +// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() +// and FilesystemImpl.BoundEndpointAt(). +type BoundEndpointOptions struct { + // Addr is the path of the file whose socket endpoint is being retrieved. + // It is generally irrelevant: most endpoints are stored at a dentry that + // was created through a bind syscall, so the path can be stored on creation. + // However, if the endpoint was created in FilesystemImpl.BoundEndpointAt(), + // then we may not know what the original bind address was. + // + // For example, if connect(2) is called with address "foo" which corresponds + // a remote named socket in goferfs, we need to generate an endpoint wrapping + // that file. In this case, we can use Addr to set the endpoint address to + // "foo". Note that Addr is only a best-effort attempt--we still do not know + // the exact address that was used on the remote fs to bind the socket (it + // may have been "foo", "./foo", etc.). + Addr string +} + // GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), // FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and // FileDescriptionImpl.Getxattr(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 9015f2cc1..8d7f8f8af 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -351,33 +351,6 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } } -// BoundEndpointAt gets the bound endpoint at the given path, if one exists. -func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (transport.BoundEndpoint, error) { - if !pop.Path.Begin.Ok() { - if pop.Path.Absolute { - return nil, syserror.ECONNREFUSED - } - return nil, syserror.ENOENT - } - rp := vfs.getResolvingPath(creds, pop) - for { - bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp) - if err == nil { - vfs.putResolvingPath(rp) - return bep, nil - } - if checkInvariants { - if rp.canHandleError(err) && rp.Done() { - panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) - } - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return nil, err - } - } -} - // OpenAt returns a FileDescription providing access to the file at the given // path. A reference is taken on the returned FileDescription. func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { @@ -675,6 +648,33 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti } } +// BoundEndpointAt gets the bound endpoint at the given path, if one exists. +func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) { + if !pop.Path.Begin.Ok() { + if pop.Path.Absolute { + return nil, syserror.ECONNREFUSED + } + return nil, syserror.ENOENT + } + rp := vfs.getResolvingPath(creds, pop) + for { + bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return bep, nil + } + if checkInvariants { + if rp.canHandleError(err) && rp.Done() { + panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) + } + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + // ListxattrAt returns all extended attribute names for the file at the given // path. func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index fcc46420f..101497ed6 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -255,7 +255,7 @@ func (w *Watchdog) runTurn() { case <-done: case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. - // No one is wathching the watchdog watcher though. + // No one is watching the watchdog watcher though. w.reportStuckWatchdog() <-done } @@ -317,10 +317,8 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - // Dump stack only if a new task is detected or if it sometime has - // passed since the last time a stack dump was generated. - showStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod - w.doAction(w.TaskTimeoutAction, showStack, &buf) + // Force stack dump only if a new task is detected. + w.doAction(w.TaskTimeoutAction, newTaskFound, &buf) } func (w *Watchdog) reportStuckWatchdog() { @@ -329,12 +327,15 @@ func (w *Watchdog) reportStuckWatchdog() { w.doAction(w.TaskTimeoutAction, false, &buf) } -// doAction will take the given action. If the action is LogWarning and -// showStack is false, then the stack printing will be skipped. -func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) { +// doAction will take the given action. If the action is LogWarning, the stack +// is not always dumpped to the log to prevent log flooding. "forceStack" +// guarantees that the stack will be dumped regarless. +func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) { switch action { case LogWarning: - if !showStack { + // Dump stack only if forced or sometime has passed since the last time a + // stack dump was generated. + if !forceStack && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) return @@ -359,6 +360,7 @@ func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String())) + default: panic(fmt.Sprintf("Unknown watchdog action %v", action)) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 0cac6c0a5..7908c5744 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -71,6 +71,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( + ICMPv4TTLExceeded = 0 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index ba80b64a8..4f367fe4c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -17,6 +17,7 @@ package header import ( "crypto/sha256" "encoding/binary" + "fmt" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -445,3 +446,54 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { return GlobalScope, nil } } + +// InitialTempIID generates the initial temporary IID history value to generate +// temporary SLAAC addresses with. +// +// Panics if initialTempIIDHistory is not at least IIDSize bytes. +func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) { + h := sha256.New() + // h.Write never returns an error. + h.Write(seed) + var nicIDBuf [4]byte + binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID)) + h.Write(nicIDBuf[:]) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } +} + +// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an +// associated stable/permanent SLAAC address. +// +// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be +// used when generating a new temporary IID. +// +// Panics if tempIIDHistory is not at least IIDSize bytes. +func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix { + addrBytes := []byte(stableAddr) + h := sha256.New() + h.Write(tempIIDHistory) + h.Write(addrBytes[IIDOffsetInIPv6Address:]) + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + // The rightmost 64 bits of sum are saved for the next iteration. + if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } + + // The leftmost 64 bits of sum is used as the IID. + if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize { + panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize)) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: IIDOffsetInIPv6Address * 8, + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 9cc08d0e2..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -18,10 +18,3 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) - -go_test( - name = "rawfile_test", - size = "small", - srcs = ["rawfile_test.go"], - library = ":rawfile", -) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go deleted file mode 100644 index 8f14ba761..000000000 --- a/pkg/tcpip/link/rawfile/rawfile_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package rawfile - -import ( - "syscall" - "testing" -) - -func TestNonBlockingWrite3ZeroLength(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} - -func TestNonBlockingWrite3Nil(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 92efd0bf8..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,13 +76,9 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. - var base *byte - if len(b1) > 0 { - base = &b1[0] - } iovec := [3]syscall.Iovec{ { - Base: base, + Base: &b1[0], Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index c11d62f97..fbc0e3f55 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -119,6 +119,32 @@ const ( // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so // 128 - 64 = 64. validPrefixLenForAutoGen = 64 + + // defaultAutoGenTempGlobalAddresses is the default configuration for whether + // or not to generate temporary SLAAC addresses. + defaultAutoGenTempGlobalAddresses = true + + // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 7 days (from RFC 4941 section 5). + defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour + + // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 1 day (from RFC 4941 section 5). + defaultMaxTempAddrPreferredLifetime = 24 * time.Hour + + // defaultRegenAdvanceDuration is the default duration before the deprecation + // of a temporary address when a new address will be generated. + // + // Default = 5s (from RFC 4941 section 5). + defaultRegenAdvanceDuration = 5 * time.Second + + // minRegenAdvanceDuration is the minimum duration before the deprecation + // of a temporary address when a new address will be generated. + minRegenAdvanceDuration = time.Duration(0) ) var ( @@ -131,6 +157,37 @@ var ( // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour + + // MaxDesyncFactor is the upper bound for the preferred lifetime's desync + // factor for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Must be greater than 0. + // + // Max = 10m (from RFC 4941 section 5). + MaxDesyncFactor = 10 * time.Minute + + // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the + // maximum preferred lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be preferred for at + // least 1hr if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour + + // MinMaxTempAddrValidLifetime is the minimum value allowed for the + // maximum valid lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be valid for at least + // 2hrs if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrValidLifetime = 2 * time.Hour ) // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an @@ -324,35 +381,49 @@ type NDPConfigurations struct { // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's // MAC address), then no attempt will be made to resolve the conflict. AutoGenAddressConflictRetries uint8 + + // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC + // addresses will be generated for a NIC as part of SLAAC privacy extensions, + // RFC 4941. + // + // Ignored if AutoGenGlobalAddresses is false. + AutoGenTempGlobalAddresses bool + + // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary + // SLAAC addresses. + MaxTempAddrValidLifetime time.Duration + + // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for + // temporary SLAAC addresses. + MaxTempAddrPreferredLifetime time.Duration + + // RegenAdvanceDuration is the duration before the deprecation of a temporary + // address when a new address will be generated. + RegenAdvanceDuration time.Duration } // DefaultNDPConfigurations returns an NDPConfigurations populated with // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - MaxRtrSolicitations: defaultMaxRtrSolicitations, - RtrSolicitationInterval: defaultRtrSolicitationInterval, - MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, + MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, + RegenAdvanceDuration: defaultRegenAdvanceDuration, } } // validate modifies an NDPConfigurations with valid values. If invalid values // are present in c, the corresponding default values will be used instead. -// -// If RetransmitTimer is less than minimumRetransmitTimer, then a value of -// defaultRetransmitTimer will be used. -// -// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then -// a value of defaultRtrSolicitationInterval will be used. -// -// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then -// a value of defaultMaxRtrSolicitationDelay will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -365,6 +436,18 @@ func (c *NDPConfigurations) validate() { if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay } + + if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { + c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime + } + + if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { + c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime + } + + if c.RegenAdvanceDuration < minRegenAdvanceDuration { + c.RegenAdvanceDuration = minRegenAdvanceDuration + } } // ndpState is the per-interface NDP state. @@ -394,6 +477,14 @@ type ndpState struct { // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA + + // temporaryIIDHistory is the history value used to generate a new temporary + // IID. + temporaryIIDHistory [header.IIDSize]byte + + // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for + // temporary SLAAC addresses. + temporaryAddressDesyncFactor time.Duration } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -414,7 +505,7 @@ type dadState struct { type defaultRouterState struct { // Timer to invalidate the default router. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer } @@ -424,20 +515,48 @@ type defaultRouterState struct { type onLinkPrefixState struct { // Timer to invalidate the on-link prefix. // - // May not be nil. + // Must not be nil. + invalidationTimer *tcpip.CancellableTimer +} + +// tempSLAACAddrState holds state associated with a temporary SLAAC address. +type tempSLAACAddrState struct { + // Timer to deprecate the temporary SLAAC address. + // + // Must not be nil. + deprecationTimer *tcpip.CancellableTimer + + // Timer to invalidate the temporary SLAAC address. + // + // Must not be nil. invalidationTimer *tcpip.CancellableTimer + + // Timer to regenerate the temporary SLAAC address. + // + // Must not be nil. + regenTimer *tcpip.CancellableTimer + + createdAt time.Time + + // The address's endpoint. + // + // Must not be nil. + ref *referencedNetworkEndpoint + + // Has a new temporary SLAAC address already been regenerated? + regenerated bool } // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { // Timer to deprecate the prefix. // - // May not be nil. + // Must not be nil. deprecationTimer *tcpip.CancellableTimer // Timer to invalidate the prefix. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer // Nonzero only when the address is not valid forever. @@ -446,19 +565,27 @@ type slaacPrefixState struct { // Nonzero only when the address is not preferred forever. preferredUntil time.Time - // The prefix's permanent address endpoint. + // The endpoint for the stable address generated for a SLAAC prefix. // // May only be nil when a SLAAC address is being (re-)generated. Otherwise, - // must not be nil as all SLAAC prefixes must have a SLAAC address. - ref *referencedNetworkEndpoint + // must not be nil as all SLAAC prefixes must have a stable address. + stableAddrRef *referencedNetworkEndpoint + + // The temporary (short-lived) addresses generated for the SLAAC prefix. + tempAddrs map[tcpip.Address]tempSLAACAddrState - // The number of times a permanent address has been generated for the prefix. + // The next two fields are used by both stable and temporary addresses + // generated for a SLAAC prefix. This is safe as only 1 address will be + // in the generation and DAD process at any time. That is, no two addresses + // will be generated at the same time for a given SLAAC prefix. + + // The number of times an address has been generated. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 - // The maximum number of times to attempt regeneration of a permanent SLAAC - // address in response to DAD conflicts. + // The maximum number of times to attempt regeneration of a SLAAC address + // in response to DAD conflicts. maxGenerationAttempts uint8 } @@ -536,10 +663,10 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. - ndp.nic.mu.Unlock() return } @@ -551,8 +678,6 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - - ndp.nic.mu.Unlock() return } @@ -560,15 +685,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // the last NDP NS. Either way, clean up addr's DAD state and let the // integrator know DAD has completed. delete(ndp.dad, addr) - ndp.nic.mu.Unlock() - - if err != nil { - log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } + + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && ref.configType == slaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + } }) ndp.dad[addr] = dadState{ @@ -953,9 +1081,10 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already maintain SLAAC state for prefix. - if _, ok := ndp.slaacPrefixes[prefix]; ok { + if state, ok := ndp.slaacPrefixes[prefix]; ok { // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. - ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) + ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) + ndp.slaacPrefixes[prefix] = state return } @@ -996,7 +1125,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.ref) + ndp.deprecateSLAACAddress(state.stableAddrRef) }), invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] @@ -1006,6 +1135,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.invalidateSLAACPrefix(prefix, state) }), + tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } @@ -1035,9 +1165,49 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { state.validUntil = now.Add(vl) } + // If the address is assigned (DAD resolved), generate a temporary address. + if state.stableAddrRef.getKind() == permanent { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) + } + ndp.slaacPrefixes[prefix] = state } +// addSLAACAddr adds a SLAAC address to the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr.Address) { + return nil + } + + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return nil + } + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + // Informed by the integrator not to add the address. + return nil + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + if err != nil { + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + } + + return ref +} + // generateSLAACAddr generates a SLAAC address for prefix. // // Returns true if an address was successfully generated. @@ -1046,7 +1216,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.ref; r != nil { + if r := state.stableAddrRef; r != nil { panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) } @@ -1085,39 +1255,18 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt return false } - generatedAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: validPrefixLenForAutoGen, - }, - } - - // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { - return false - } - - // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return false + generatedAddr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { - // Informed by the integrator not to add the address. - return false + if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { + state.stableAddrRef = ref + state.generationAttempts++ + return true } - deprecated := time.Since(state.preferredUntil) >= 0 - ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) - if err != nil { - panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) - } - - state.generationAttempts++ - state.ref = ref - return true + return false } // regenerateSLAACAddr regenerates an address for a SLAAC prefix. @@ -1143,24 +1292,167 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { ndp.invalidateSLAACPrefix(prefix, state) } -// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// generateTempSLAACAddr generates a new temporary SLAAC address. // -// pl is the new preferred lifetime. vl is the new valid lifetime. +// If resetGenAttempts is true, the prefix's generation counter will be reset. +// +// Returns true if a new address was generated. +func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { + // Are we configured to auto-generate new temporary global addresses for the + // prefix? + if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { + return false + } + + if resetGenAttempts { + prefixState.generationAttempts = 0 + prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if prefixState.generationAttempts == prefixState.maxGenerationAttempts { + return false + } + + stableAddr := prefixState.stableAddrRef.ep.ID().LocalAddress + now := time.Now() + + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. + vl := ndp.configs.MaxTempAddrValidLifetime + if prefixState.validUntil != (time.Time{}) { + if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { + vl = prefixVL + } + } + + if vl <= 0 { + // Cannot create an address without a valid lifetime. + return false + } + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or the + // maximum temporary address preferred lifetime - the temporary address desync + // factor. + pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor + if prefixState.preferredUntil != (time.Time{}) { + if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { + // Respect the preferred lifetime of the prefix, as per RFC 4941 section + // 3.3 step 4. + pl = prefixPL + } + } + + // As per RFC 4941 section 3.3 step 5, a temporary address is created only if + // the calculated preferred lifetime is greater than the advance regeneration + // duration. In particular, we MUST NOT create a temporary address with a zero + // Preferred Lifetime. + if pl <= ndp.configs.RegenAdvanceDuration { + return false + } + + generatedAddr := header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + + // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary + // address with a zero preferred lifetime. The checks above ensure this + // so we know the address is not deprecated. + ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) + if ref == nil { + return false + } + + state := tempSLAACAddrState{ + deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) + } + + ndp.deprecateSLAACAddress(tempAddrState.ref) + }), + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) + } + + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) + }), + regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) + } + + // If an address has already been regenerated for this address, don't + // regenerate another address. + if tempAddrState.regenerated { + return + } + + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) + prefixState.tempAddrs[generatedAddr.Address] = tempAddrState + ndp.slaacPrefixes[prefix] = prefixState + }), + createdAt: now, + ref: ref, + } + + state.deprecationTimer.Reset(pl) + state.invalidationTimer.Reset(vl) + state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + + prefixState.generationAttempts++ + prefixState.tempAddrs[generatedAddr.Address] = state + + return true +} + +// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { - prefixState, ok := ndp.slaacPrefixes[prefix] +func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { + state, ok := ndp.slaacPrefixes[prefix] if !ok { - panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)) + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) } - defer func() { ndp.slaacPrefixes[prefix] = prefixState }() + ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) + ndp.slaacPrefixes[prefix] = state +} + +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddrRef) } else { - prefixState.ref.deprecated = false + prefixState.stableAddrRef.deprecated = false } // If prefix was preferred for some finite lifetime before, stop the @@ -1190,36 +1482,118 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim // // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. - // Handle the infinite valid lifetime separately as we do not keep a timer in - // this case. if vl >= header.NDPInfiniteLifetime { + // Handle the infinite valid lifetime separately as we do not keep a timer + // in this case. prefixState.invalidationTimer.StopLocked() prefixState.validUntil = time.Time{} - return - } + } else { + var effectiveVl time.Duration + var rl time.Duration + + // If the prefix was originally set to be valid forever, assume the + // remaining time to be the maximum possible value. + if prefixState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(prefixState.validUntil) + } - var effectiveVl time.Duration - var rl time.Duration + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl > MinPrefixInformationValidLifetimeForUpdate { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } - // If the prefix was originally set to be valid forever, assume the remaining - // time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(prefixState.validUntil) + if effectiveVl != 0 { + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) + } } - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + // If DAD is not yet complete on the stable address, there is no need to do + // work with temporary addresses. + if prefixState.stableAddrRef.getKind() != permanent { return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + // Note, we do not need to update the entries in the temporary address map + // after updating the timers because the timers are held as pointers. + var regenForAddr tcpip.Address + allAddressesRegenerated := true + for tempAddr, tempAddrState := range prefixState.tempAddrs { + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. Note, the valid lifetime of a + // temporary address is relative to the address's creation time. + validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) + if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + validUntil = prefixState.validUntil + } + + // If the address is no longer valid, invalidate it immediately. Otherwise, + // reset the invalidation timer. + newValidLifetime := validUntil.Sub(now) + if newValidLifetime <= 0 { + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) + continue + } + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.invalidationTimer.Reset(newValidLifetime) + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or + // the maximum temporary address preferred lifetime - the temporary address + // desync factor. Note, the preferred lifetime of a temporary address is + // relative to the address's creation time. + preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) + if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + preferredUntil = prefixState.preferredUntil + } + + // If the address is no longer preferred, deprecate it immediately. + // Otherwise, reset the deprecation timer. + newPreferredLifetime := preferredUntil.Sub(now) + tempAddrState.deprecationTimer.StopLocked() + if newPreferredLifetime <= 0 { + ndp.deprecateSLAACAddress(tempAddrState.ref) + } else { + tempAddrState.ref.deprecated = false + tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + } + + tempAddrState.regenTimer.StopLocked() + if tempAddrState.regenerated { + } else { + allAddressesRegenerated = false + + if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { + // The new preferred lifetime is less than the advance regeneration + // duration so regenerate an address for this temporary address + // immediately after we finish iterating over the temporary addresses. + regenForAddr = tempAddr + } else { + tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + } + } + } + + // Generate a new temporary address if all of the existing temporary addresses + // have been regenerated, or we need to immediately regenerate an address + // due to an update in preferred lifetime. + // + // If each temporay address has already been regenerated, no new temporary + // address will be generated. To ensure continuation of temporary SLAAC + // addresses, we manually try to regenerate an address here. + if len(regenForAddr) != 0 || allAddressesRegenerated { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { + state.regenerated = true + prefixState.tempAddrs[regenForAddr] = state + } + } } // deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP @@ -1243,11 +1617,11 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.ref; r != nil { + if r := state.stableAddrRef; r != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) } } @@ -1265,14 +1639,14 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + if !ok || state.stableAddrRef == nil || addr.Address != state.stableAddrRef.ep.ID().LocalAddress { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.ref = nil + state.stableAddrRef = nil ndp.slaacPrefixes[prefix] = state return } @@ -1286,11 +1660,68 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + // Invalidate all temporary addresses. + for tempAddr, tempAddrState := range state.tempAddrs { + ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) + } + + state.stableAddrRef = nil state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() delete(ndp.slaacPrefixes, prefix) } +// invalidateTempSLAACAddr invalidates a temporary SLAAC address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + // Since we are already invalidating the address, do not invalidate the + // address when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + } + + ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary +// SLAAC address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } + + if !invalidateAddr { + return + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) + } + + tempAddrState, ok := state.tempAddrs[addr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) + } + + ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's +// timers and entry. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.deprecationTimer.StopLocked() + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.regenTimer.StopLocked() + delete(tempAddrs, tempAddr) +} + // cleanupState cleans up ndp's state. // // If hostOnly is true, then only host-specific state will be cleaned up. @@ -1450,3 +1881,13 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer.Stop() ndp.rtrSolicitTimer = nil } + +// initializeTempAddrState initializes state related to temporary SLAAC +// addresses. +func (ndp *ndpState) initializeTempAddrState() { + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + + if MaxDesyncFactor != 0 { + ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6dd460984..421df674f 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1801,6 +1801,726 @@ func TestAutoGenAddr(t *testing.T) { } } +func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { + ret := "" + for _, c := range containList { + if !containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should have %s in the list of addresses\n", c) + } + } + for _, c := range notContainList { + if containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) + } + } + return ret +} + +// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when +// configured to do so as part of IPv6 Privacy Extensions. +func TestAutoGenTempAddr(t *testing.T) { + const ( + nicID = 1 + newMinVL = 5 + newMinVLDuration = newMinVL * time.Second + ) + + savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + stack.MaxDesyncFactor = time.Nanosecond + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for i, test := range tests { + i := i + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred + // lifetimes. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } + }) +} + +// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not +// generated for auto generated link-local addresses. +func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { + const nicID = 1 + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = time.Nanosecond + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + }) +} + +// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address +// will not be generated until after DAD completes, even if a new Router +// Advertisement is received to refresh lifetimes. +func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { + const ( + nicID = 1 + dadTransmits = 1 + retransmitTimer = 2 * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = 0 + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Receive an RA to trigger SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // DAD on the stable address for prefix has not yet completed. Receiving a new + // RA that would refresh lifetimes should not generate a temporary SLAAC + // address for the prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + + // Wait for DAD to complete for the stable address then expect the temporary + // address to be generated. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } +} + +// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are +// regenerated. +func TestAutoGenTempAddrRegen(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Stop generating temporary addresses + ndpConfigs.AutoGenTempGlobalAddresses = false + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Wait for all the temporary addresses to get invalidated. + tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} + invalidateAfter := newMinVLDuration - 2*regenAfter + for _, addr := range tempAddrs { + // Wait for a deprecation then invalidation event, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation timers could fire in any + // order. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we shouldn't get a deprecation + // event after. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated event = %+v", e) + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event = %+v", e) + } + case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + + invalidateAfter = regenAfter + } + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + t.Fatal(mismatch) + } +} + +// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's +// regeneration timer gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate the prefix. + // + // A new temporary address should be generated after the regeneration + // time has passed since the prefix is deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + expectAutoGenAddrEvent(addr, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Prefer the prefix again. + // + // A new temporary address should immediately be generated since the + // regeneration time has already passed since the last address was generated + // - this regeneration does not depend on a timer. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr2, newAddr) + + // Increase the maximum lifetimes for temporary addresses to large values + // then refresh the lifetimes of the prefix. + // + // A new address should not be generated after the regeneration time that was + // expected for the previous check. This is because the preferred lifetime for + // the temporary addresses has increased, so it will take more time to + // regenerate a new temporary address. Note, new addresses are only + // regenerated after the preferred lifetime - the regenerate advance duration + // as paased. + ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second + ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Set the maximum lifetimes for temporary addresses such that on the next + // RA, the regeneration timer gets reset. + // + // The maximum lifetime is the sum of the minimum lifetimes for temporary + // addresses + the time that has already passed since the last address was + // generated so that the regeneration timer is needed to generate the next + // address. + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + ndpConfigs.MaxTempAddrValidLifetime = newLifetimes + ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) +} + // stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, // channel.Endpoint and stack.Stack. // @@ -2196,7 +2916,6 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2808,9 +3527,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } -// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an -// auto-generated IPv6 address in response to a DAD conflict. -func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { +func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const nicID = 1 const nicName = "nic" const dadTransmits = 1 @@ -2818,6 +3535,13 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 + // Needed for the temporary address sub test. + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MaxDesyncFactor = time.Nanosecond + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] n, err := rand.Read(secretKey) @@ -2830,185 +3554,234 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - addrTypes := []struct { - name string - ndpConfigs stack.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { + addrBytes := []byte(subnet.ID()) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), + PrefixLen: 64, + } + } - }, - }, - { - name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, + expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - for _, addrType := range addrTypes { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType + expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { - t.Parallel() + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - addrType.triggerSLAACFn(e) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) + stableAddrForTempAddrTest := addrForSubnet(subnet, 0) - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + return nil + + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(subnet, dadCounter) + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + autoGenLinkLocal: true, + prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + return nil + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) + }, + }, + { + name: "Temporary address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + header.InitialTempIID(tempIIDHistory, nil, nicID) + + // Generate a stable SLAAC address so temporary addresses will be + // generated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + + // The stable address will be assigned throughout the test. + return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} + }, + addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) + }, + }, + } + + for _, addrType := range addrTypes { + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the parallel + // tests complete and limit the number of parallel tests running at the same + // time to reduce flakes. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run(addrType.name, func(t *testing.T) { + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := addrType.ndpConfigs + ndpConfigs.AutoGenAddressConflictRetries = maxRetries + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + var tempIIDHistory [header.IIDSize]byte + stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addr := addrType.addrGenFn(i, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - default: - t.Fatal("expected DAD event") - } - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) } - default: - t.Fatal("expected DAD event") - } - } + expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) + expectDADEvent(t, &ndpDisp, addr.Address, false) - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) - } + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + expectDADEvent(t, &ndpDisp, addr.Address, false) + } - // If we had less failures than generation attempts, we should have an - // address after DAD resolves. - if maxRetries+1 > numFailures { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), - PrefixLen: 64, + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - expectAutoGenAddrEvent(addr, newAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // If we had less failures than generation attempts, we should have + // an address after DAD resolves. + if maxRetries+1 > numFailures { + addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, &ndpDisp, addr.Address, true) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { + t.Fatal(mismatch) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for DAD event") } - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if mainAddr != addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + // Should not attempt address generation again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): } - } - - // Should not attempt address regeneration again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): - } - }) + }) + } } - } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..25188b4fb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -131,6 +131,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } + nic.mu.ndp.initializeTempAddrState() // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { @@ -1014,14 +1015,14 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { switch r.protocol { case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) default: r.expireLocked() return nil } } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { addr := r.addrWithPrefix() isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) @@ -1031,8 +1032,11 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. - if r.configType == slaac { - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) + switch r.configType { + case slaac: + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case slaacTemp: + n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) } } @@ -1203,12 +1207,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1293,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1336,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1379,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } @@ -1436,12 +1452,19 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { return err } - if ref.configType == slaac { - n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + prefix := ref.addrWithPrefix().Subnet() + + switch ref.configType { + case slaac: + n.mu.ndp.regenerateSLAACAddr(prefix) + case slaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) } return nil @@ -1540,9 +1563,14 @@ const ( // multicast group). static networkEndpointConfigType = iota - // A slaac configured endpoint is an IPv6 endpoint that was - // added by SLAAC as per RFC 4862 section 5.5.3. + // A SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4862 section 5.5.3. slaac + + // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are + // not expected to be valid (or preferred) forever; hence the term temporary. + slaacTemp ) type referencedNetworkEndpoint struct { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7d36f8e84..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,13 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41398a1b6..4a2dc3dc6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -464,6 +464,10 @@ type Stack struct { // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -541,6 +545,21 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -664,6 +683,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 1dd63dd61..ace79b7b2 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" ) // createConnectedWithSACKPermittedOption creates and connects c.ep with the @@ -439,21 +440,28 @@ func TestSACKRecovery(t *testing.T) { // Receive the retransmitted packet. c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) + } } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause @@ -517,22 +525,28 @@ func TestSACKRecovery(t *testing.T) { bytesRead += maxPayload } - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } + metricPollFn = func() error { + // In SACK recovery only the first segment is fast retransmitted when + // entering recovery. + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..7e574859b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/pkg/waiter" ) @@ -209,8 +210,15 @@ func TestTCPResetsSentIncrement(t *testing.T) { c.SendPacket(nil, ackHeaders) c.GetPacket() - if got := stats.TCP.ResetsSent.Value(); got != want { - t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + + metricPollFn := func() error { + if got := stats.TCP.ResetsSent.Value(); got != want { + return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } } @@ -3548,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true |