summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD2
-rw-r--r--pkg/abi/linux/ptrace_amd64.go52
-rw-r--r--pkg/abi/linux/ptrace_arm64.go (renamed from pkg/sentry/arch/arch_state_aarch64.go)25
-rw-r--r--pkg/sentry/arch/BUILD4
-rw-r--r--pkg/sentry/arch/arch_aarch64.go26
-rw-r--r--pkg/sentry/arch/arch_amd64.go13
-rw-r--r--pkg/sentry/arch/arch_arm64.go2
-rw-r--r--pkg/sentry/arch/arch_state_x86.go42
-rw-r--r--pkg/sentry/arch/arch_x86.go25
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go4
-rw-r--r--pkg/sentry/arch/signal.go3
-rw-r--r--pkg/sentry/arch/signal_act.go4
-rw-r--r--pkg/sentry/arch/signal_stack.go3
-rw-r--r--pkg/sentry/fs/gofer/socket.go4
-rw-r--r--pkg/sentry/fs/host/socket.go4
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go4
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD8
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go39
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go128
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go32
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go141
-rw-r--r--pkg/sentry/fsimpl/host/BUILD10
-rw-r--r--pkg/sentry/fsimpl/host/host.go57
-rw-r--r--pkg/sentry/fsimpl/host/socket.go397
-rw-r--r--pkg/sentry/fsimpl/host/socket_iovec.go113
-rw-r--r--pkg/sentry/fsimpl/host/socket_unsafe.go101
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go7
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go20
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go5
-rw-r--r--pkg/sentry/kernel/task_signals.go8
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go36
-rw-r--r--pkg/sentry/platform/kvm/testutil/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_amd64.go17
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go13
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_amd64.go7
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64.go5
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go8
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go16
-rw-r--r--pkg/sentry/platform/ring0/BUILD1
-rw-r--r--pkg/sentry/platform/ring0/defs.go9
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go6
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD1
-rw-r--r--pkg/sentry/platform/ring0/offsets_amd64.go5
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go5
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go5
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go5
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/unix.go9
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go61
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go10
-rw-r--r--pkg/sentry/vfs/anonfs.go5
-rw-r--r--pkg/sentry/vfs/filesystem.go10
-rw-r--r--pkg/sentry/vfs/options.go18
-rw-r--r--pkg/sentry/vfs/vfs.go54
-rw-r--r--pkg/sentry/watchdog/watchdog.go20
-rw-r--r--pkg/tcpip/buffer/view.go55
-rw-r--r--pkg/tcpip/buffer/view_test.go113
-rw-r--r--pkg/tcpip/header/icmpv4.go1
-rw-r--r--pkg/tcpip/header/ipv6.go52
-rw-r--r--pkg/tcpip/link/loopback/loopback.go10
-rw-r--r--pkg/tcpip/link/rawfile/BUILD9
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_test.go46
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go2
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go65
-rw-r--r--pkg/tcpip/network/arp/arp.go5
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go12
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go74
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go3
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/stack/forwarder_test.go13
-rw-r--r--pkg/tcpip/stack/iptables.go22
-rw-r--r--pkg/tcpip/stack/iptables_targets.go23
-rw-r--r--pkg/tcpip/stack/ndp.go645
-rw-r--r--pkg/tcpip/stack/ndp_test.go1089
-rw-r--r--pkg/tcpip/stack/nic.go68
-rw-r--r--pkg/tcpip/stack/packet_buffer.go8
-rw-r--r--pkg/tcpip/stack/stack.go20
-rw-r--r--pkg/tcpip/stack/stack_test.go10
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/segment.go29
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go68
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go16
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/udp/protocol.go9
93 files changed, 3079 insertions, 1089 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 59b0e138a..114b516e2 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -44,6 +44,8 @@ go_library(
"poll.go",
"prctl.go",
"ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
"rseq.go",
"rusage.go",
"sched.go",
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
new file mode 100644
index 000000000..ed3881e27
--- /dev/null
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
+ R15 uint64
+ R14 uint64
+ R13 uint64
+ R12 uint64
+ Rbp uint64
+ Rbx uint64
+ R11 uint64
+ R10 uint64
+ R9 uint64
+ R8 uint64
+ Rax uint64
+ Rcx uint64
+ Rdx uint64
+ Rsi uint64
+ Rdi uint64
+ Orig_rax uint64
+ Rip uint64
+ Cs uint64
+ Eflags uint64
+ Rsp uint64
+ Ss uint64
+ Fs_base uint64
+ Gs_base uint64
+ Ds uint64
+ Es uint64
+ Fs uint64
+ Gs uint64
+}
diff --git a/pkg/sentry/arch/arch_state_aarch64.go b/pkg/abi/linux/ptrace_arm64.go
index 0136a85ad..6147738b3 100644
--- a/pkg/sentry/arch/arch_state_aarch64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,25 +14,16 @@
// +build arm64
-package arch
+package linux
-import (
- "syscall"
-)
-
-type syscallPtraceRegs struct {
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
Regs [31]uint64
Sp uint64
Pc uint64
Pstate uint64
}
-
-// saveRegs is invoked by stateify.
-func (s *State) saveRegs() syscallPtraceRegs {
- return syscallPtraceRegs(s.Regs)
-}
-
-// loadRegs is invoked by stateify.
-func (s *State) loadRegs(r syscallPtraceRegs) {
- s.Regs = syscall.PtraceRegs(r)
-}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index e27f21e5e..901e0f320 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -11,7 +11,6 @@ go_library(
"arch_amd64.go",
"arch_amd64.s",
"arch_arm64.go",
- "arch_state_aarch64.go",
"arch_state_x86.go",
"arch_x86.go",
"arch_x86_impl.go",
@@ -26,11 +25,11 @@ go_library(
"syscalls_amd64.go",
"syscalls_arm64.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
":registers_go_proto",
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/cpuid",
"//pkg/log",
@@ -38,6 +37,7 @@ go_library(
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index c29e1b841..529980267 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -17,18 +17,20 @@
package arch
import (
+ "encoding/binary"
"fmt"
"io"
- "syscall"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
+// Registers represents the CPU registers for this architecture.
+type Registers = linux.PtraceRegs
+
const (
// SyscallWidth is the width of insturctions.
SyscallWidth = 4
@@ -90,7 +92,7 @@ func NewFloatingPointData() *FloatingPointData {
// file ensures it's only built on aarch64).
type State struct {
// The system registers.
- Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+ Regs Registers
// Our floating point state.
aarch64FPState `state:"wait"`
@@ -226,25 +228,27 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
// PtraceGetRegs implements Context.PtraceGetRegs.
func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
- return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
}
-func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+func (s *State) ptraceGetRegs() Registers {
return s.Regs
}
-var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+var registersSize = (*Registers)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
- var regs syscall.PtraceRegs
- buf := make([]byte, ptraceRegsSize)
+ var regs Registers
+ buf := make([]byte, registersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
- binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ regs.UnmarshalUnsafe(buf)
s.Regs = regs
- return ptraceRegsSize, nil
+ return registersSize, nil
}
// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 85d6acc0f..3b3a0a272 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -22,7 +22,6 @@ import (
"math/rand"
"syscall"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
@@ -301,8 +300,10 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(ptraceRegsSize) {
- buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ if addr < uintptr(registersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
}
// Note: x86 debug registers are missing.
@@ -314,8 +315,10 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(ptraceRegsSize) {
- buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ if addr < uintptr(registersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
usermem.ByteOrder.PutUint64(buf[addr:], uint64(data))
_, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
return err
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index db99c5acb..ada7ac7b8 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package arch
import (
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index aa31169e0..19ce99d25 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -18,7 +18,6 @@ package arch
import (
"fmt"
- "syscall"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/usermem"
@@ -90,44 +89,3 @@ func (s *State) afterLoadFPState() {
// Copy to the new, aligned location.
copy(s.x86FPState, old)
}
-
-// +stateify savable
-type syscallPtraceRegs struct {
- R15 uint64
- R14 uint64
- R13 uint64
- R12 uint64
- Rbp uint64
- Rbx uint64
- R11 uint64
- R10 uint64
- R9 uint64
- R8 uint64
- Rax uint64
- Rcx uint64
- Rdx uint64
- Rsi uint64
- Rdi uint64
- Orig_rax uint64
- Rip uint64
- Cs uint64
- Eflags uint64
- Rsp uint64
- Ss uint64
- Fs_base uint64
- Gs_base uint64
- Ds uint64
- Es uint64
- Fs uint64
- Gs uint64
-}
-
-// saveRegs is invoked by stateify.
-func (s *State) saveRegs() syscallPtraceRegs {
- return syscallPtraceRegs(s.Regs)
-}
-
-// loadRegs is invoked by stateify.
-func (s *State) loadRegs(r syscallPtraceRegs) {
- s.Regs = syscall.PtraceRegs(r)
-}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 7fc4c0473..dc458b37f 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -21,7 +21,7 @@ import (
"io"
"syscall"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
@@ -30,6 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// Registers represents the CPU registers for this architecture.
+type Registers = linux.PtraceRegs
+
// System-related constants for x86.
const (
// SyscallWidth is the width of syscall, sysenter, and int 80 insturctions.
@@ -267,10 +270,12 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
// PtraceGetRegs implements Context.PtraceGetRegs.
func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
- return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
}
-func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+func (s *State) ptraceGetRegs() Registers {
regs := s.Regs
// These may not be initialized.
if regs.Cs == 0 || regs.Ss == 0 || regs.Eflags == 0 {
@@ -306,16 +311,16 @@ func (s *State) ptraceGetRegs() syscall.PtraceRegs {
return regs
}
-var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+var registersSize = (*Registers)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
- var regs syscall.PtraceRegs
- buf := make([]byte, ptraceRegsSize)
+ var regs Registers
+ buf := make([]byte, registersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
- binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ regs.UnmarshalUnsafe(buf)
// Truncate segment registers to 16 bits.
regs.Cs = uint64(uint16(regs.Cs))
regs.Ds = uint64(uint16(regs.Ds))
@@ -369,7 +374,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return ptraceRegsSize, nil
+ return registersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -538,7 +543,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < ptraceRegsSize {
+ if maxlen < registersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -558,7 +563,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < ptraceRegsSize {
+ if maxlen < registersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
index 3edf40764..0c73fcbfb 100644
--- a/pkg/sentry/arch/arch_x86_impl.go
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -17,8 +17,6 @@
package arch
import (
- "syscall"
-
"gvisor.dev/gvisor/pkg/cpuid"
)
@@ -28,7 +26,7 @@ import (
// +stateify savable
type State struct {
// The system registers.
- Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+ Regs Registers
// Our floating point state.
x86FPState `state:"wait"`
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index 8b03d0187..c9fb55d00 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -22,6 +22,7 @@ import (
// SignalAct represents the action that should be taken when a signal is
// delivered, and is equivalent to struct sigaction.
//
+// +marshal
// +stateify savable
type SignalAct struct {
Handler uint64
@@ -43,6 +44,7 @@ func (s *SignalAct) DeserializeTo(other *SignalAct) {
// SignalStack represents information about a user stack, and is equivalent to
// stack_t.
//
+// +marshal
// +stateify savable
type SignalStack struct {
Addr uint64
@@ -64,6 +66,7 @@ func (s *SignalStack) DeserializeTo(other *SignalStack) {
// SignalInfo represents information about a signal being delivered, and is
// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h).
//
+// +marshal
// +stateify savable
type SignalInfo struct {
Signo int32 // Signal number
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
index f9ca2e74e..32173aa20 100644
--- a/pkg/sentry/arch/signal_act.go
+++ b/pkg/sentry/arch/signal_act.go
@@ -14,6 +14,8 @@
package arch
+import "gvisor.dev/gvisor/tools/go_marshal/marshal"
+
// Special values for SignalAct.Handler.
const (
// SignalActDefault is SIG_DFL and specifies that the default behavior for
@@ -71,6 +73,8 @@ func (s SignalAct) HasRestorer() bool {
// NativeSignalAct is a type that is equivalent to struct sigaction in the
// guest architecture.
type NativeSignalAct interface {
+ marshal.Marshallable
+
// SerializeFrom copies the data in the host SignalAct s into this object.
SerializeFrom(s *SignalAct)
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index e58f055c7..0fa738a1d 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -18,6 +18,7 @@ package arch
import (
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
const (
@@ -55,6 +56,8 @@ func (s *SignalStack) Contains(sp usermem.Addr) bool {
// NativeSignalStack is a type that is equivalent to stack_t in the guest
// architecture.
type NativeSignalStack interface {
+ marshal.Marshallable
+
// SerializeFrom copies the data in the host SignalStack s into this
// object.
SerializeFrom(s *SignalStack)
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index 10ba2f5f0..40f2c1cad 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -47,6 +47,8 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
return &endpoint{inode, i.fileState.file.file, path}
}
+// LINT.IfChange
+
// endpoint is a Gofer-backed transport.BoundEndpoint.
//
// An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint()
@@ -146,3 +148,5 @@ func (e *endpoint) Release() {
func (e *endpoint) Passcred() bool {
return false
}
+
+// LINT.ThenChange(../../fsimpl/gofer/socket.go)
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 06fc2d80a..b6e94583e 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -37,6 +37,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
//
// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
@@ -388,3 +390,5 @@ func (c *ConnectedEndpoint) Release() {
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+
+// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index af6955675..5c18dbd5e 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// maxIovs is the maximum number of iovecs to pass to the host.
var maxIovs = linux.UIO_MAXIOV
@@ -111,3 +113,5 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
return total, iovecs, nil, err
}
+
+// LINT.ThenChange(../../fsimpl/host/socket_iovec.go)
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index f3bbed7ea..5d4f312cf 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,6 +19,8 @@ import (
"unsafe"
)
+// LINT.IfChange
+
// fdReadVec receives from fd to bufs.
//
// If the total length of bufs is > maxlen, fdReadVec will do a partial read
@@ -99,3 +101,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int6
return int64(n), length, err
}
+
+// LINT.ThenChange(../../fsimpl/host/socket_unsafe.go)
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 2c22a04af..77b644275 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -485,11 +485,14 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
- _, _, err := fs.walk(rp, false)
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ _, inode, err := fs.walk(rp, false)
if err != nil {
return nil, err
}
+ if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
// TODO(b/134676337): Support sockets.
return nil, syserror.ECONNREFUSED
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index b9c4beee4..5ce82b793 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"p9file.go",
"pagemath.go",
"regular_file.go",
+ "socket.go",
"special_file.go",
"symlink.go",
"time.go",
@@ -52,18 +53,25 @@ go_library(
"//pkg/p9",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/host",
"//pkg/sentry/hostfd",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
"//pkg/usermem",
+ "//pkg/waiter",
],
)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 55f9ed911..b98218753 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -15,6 +15,7 @@
package gofer
import (
+ "fmt"
"sync"
"sync/atomic"
@@ -22,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -59,28 +62,52 @@ func (d *dentry) cacheNegativeLookupLocked(name string) {
d.children[name] = nil
}
-// createSyntheticDirectory creates a synthetic directory with the given name
+type createSyntheticOpts struct {
+ name string
+ mode linux.FileMode
+ kuid auth.KUID
+ kgid auth.KGID
+
+ // The endpoint for a synthetic socket. endpoint should be nil if the file
+ // being created is not a socket.
+ endpoint transport.BoundEndpoint
+
+ // pipe should be nil if the file being created is not a pipe.
+ pipe *pipe.VFSPipe
+}
+
+// createSyntheticChildLocked creates a synthetic file with the given name
// in d.
//
// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain
// a child with the given name.
-func (d *dentry) createSyntheticDirectoryLocked(name string, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) {
+func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
d2 := &dentry{
refs: 1, // held by d
fs: d.fs,
- mode: uint32(mode) | linux.S_IFDIR,
- uid: uint32(kuid),
- gid: uint32(kgid),
+ mode: uint32(opts.mode),
+ uid: uint32(opts.kuid),
+ gid: uint32(opts.kgid),
blockSize: usermem.PageSize, // arbitrary
handle: handle{
fd: -1,
},
nlink: uint32(2),
}
+ switch opts.mode.FileType() {
+ case linux.S_IFDIR:
+ // Nothing else needs to be done.
+ case linux.S_IFSOCK:
+ d2.endpoint = opts.endpoint
+ case linux.S_IFIFO:
+ d2.pipe = opts.pipe
+ default:
+ panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType()))
+ }
d2.pf.dentry = d2
d2.vfsd.Init(d2)
- d.cacheNewChildLocked(d2, name)
+ d.cacheNewChildLocked(d2, opts.name)
d.syntheticChildren++
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 98ccb42fd..4a8411371 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -22,9 +22,11 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Sync implements vfs.FilesystemImpl.Sync.
@@ -652,7 +654,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
- parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID)
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
}
if fs.opts.interop != InteropModeShared {
parent.incLinks()
@@ -663,7 +670,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// Can't create non-synthetic files in synthetic directories.
return syserror.EPERM
}
- parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID)
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
parent.incLinks()
return nil
})
@@ -674,6 +686,28 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
creds := rp.Credentials()
_, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ if err == syserror.EPERM {
+ switch opts.Mode.FileType() {
+ case linux.S_IFSOCK:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ endpoint: opts.Endpoint,
+ })
+ return nil
+ case linux.S_IFIFO:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ })
+ return nil
+ }
+ }
return err
}, nil)
}
@@ -756,20 +790,21 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
return nil, err
}
mnt := rp.Mount()
- filetype := d.fileType()
- switch {
- case filetype == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD:
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
- return nil, err
- }
- fd := &regularFileFD{}
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
- AllowDirectIO: true,
- }); err != nil {
- return nil, err
+ switch d.fileType() {
+ case linux.S_IFREG:
+ if !d.fs.opts.regularFilesUseSpecialFileFD {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
+ return nil, err
+ }
+ fd := &regularFileFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
}
- return &fd.vfsfd, nil
- case filetype == linux.S_IFDIR:
+ case linux.S_IFDIR:
// Can't open directories with O_CREAT.
if opts.Flags&linux.O_CREAT != 0 {
return nil, syserror.EISDIR
@@ -791,26 +826,40 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
return nil, err
}
return &fd.vfsfd, nil
- case filetype == linux.S_IFLNK:
+ case linux.S_IFLNK:
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
- default:
- if opts.Flags&linux.O_DIRECT != 0 {
- return nil, syserror.EINVAL
- }
- h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0)
- if err != nil {
- return nil, err
- }
- fd := &specialFileFD{
- handle: h,
+ case linux.S_IFSOCK:
+ if d.isSynthetic() {
+ return nil, syserror.ENXIO
}
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- h.close(ctx)
- return nil, err
+ case linux.S_IFIFO:
+ if d.isSynthetic() {
+ return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags)
}
- return &fd.vfsfd, nil
}
+ return d.openSpecialFileLocked(ctx, mnt, opts)
+}
+
+func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ // Treat as a special file. This is done for non-synthetic pipes as well as
+ // regular files when d.fs.opts.regularFilesUseSpecialFileFD is true.
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0)
+ if err != nil {
+ return nil, err
+ }
+ fd := &specialFileFD{
+ handle: h,
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
}
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
@@ -1196,15 +1245,28 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
- _, err := fs.resolveLocked(ctx, rp, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
}
- // TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if d.isSocket() {
+ if !d.isSynthetic() {
+ d.IncRef()
+ return &endpoint{
+ dentry: d,
+ file: d.file.file,
+ path: opts.Addr,
+ }, nil
+ }
+ return d.endpoint, nil
+ }
return nil, syserror.ECONNREFUSED
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 8b4e91d17..3235816c2 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -46,9 +46,11 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
@@ -585,6 +587,14 @@ type dentry struct {
// and target are protected by dataMu.
haveTarget bool
target string
+
+ // If this dentry represents a synthetic socket file, endpoint is the
+ // transport endpoint bound to this file.
+ endpoint transport.BoundEndpoint
+
+ // If this dentry represents a synthetic named pipe, pipe is the pipe
+ // endpoint bound to this file.
+ pipe *pipe.VFSPipe
}
// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
@@ -791,10 +801,21 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
setLocalAtime = stat.Mask&linux.STATX_ATIME != 0
setLocalMtime = stat.Mask&linux.STATX_MTIME != 0
stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME
- if !setLocalMtime && (stat.Mask&linux.STATX_SIZE != 0) {
- // Truncate updates mtime.
- setLocalMtime = true
- stat.Mtime.Nsec = linux.UTIME_NOW
+
+ // Prepare for truncate.
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ if !setLocalMtime {
+ // Truncate updates mtime.
+ setLocalMtime = true
+ stat.Mtime.Nsec = linux.UTIME_NOW
+ }
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
}
}
d.metadataMu.Lock()
@@ -1049,6 +1070,7 @@ func (d *dentry) destroyLocked() {
d.handle.close(ctx)
}
d.handleMu.Unlock()
+
if !d.file.isNil() {
d.file.close(ctx)
d.file = p9file{}
@@ -1134,7 +1156,7 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name
return d.file.removeXattr(ctx, name)
}
-// Preconditions: !d.file.isNil(). d.isRegularFile() || d.isDirectory().
+// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDirectory().
func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
// O_TRUNC unconditionally requires us to obtain a new handle (opened with
// O_TRUNC).
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
new file mode 100644
index 000000000..73835df91
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func (d *dentry) isSocket() bool {
+ return d.fileType() == linux.S_IFSOCK
+}
+
+// endpoint is a Gofer-backed transport.BoundEndpoint.
+//
+// An endpoint's lifetime is the time between when filesystem.BoundEndpointAt()
+// is called and either BoundEndpoint.BidirectionalConnect or
+// BoundEndpoint.UnidirectionalConnect is called.
+type endpoint struct {
+ // dentry is the filesystem dentry which produced this endpoint.
+ dentry *dentry
+
+ // file is the p9 file that contains a single unopened fid.
+ file p9.File
+
+ // path is the sentry path where this endpoint is bound.
+ path string
+}
+
+func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
+ switch t {
+ case linux.SOCK_STREAM:
+ return p9.StreamSocket, true
+ case linux.SOCK_SEQPACKET:
+ return p9.SeqpacketSocket, true
+ case linux.SOCK_DGRAM:
+ return p9.DgramSocket, true
+ }
+ return 0, false
+}
+
+// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
+func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
+ cf, ok := sockTypeToP9(ce.Type())
+ if !ok {
+ return syserr.ErrConnectionRefused
+ }
+
+ // No lock ordering required as only the ConnectingEndpoint has a mutex.
+ ce.Lock()
+
+ // Check connecting state.
+ if ce.Connected() {
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ if err != nil {
+ ce.Unlock()
+ return err
+ }
+
+ returnConnect(c, c)
+ ce.Unlock()
+ c.Init()
+
+ return nil
+}
+
+// UnidirectionalConnect implements
+// transport.BoundEndpoint.UnidirectionalConnect.
+func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
+ c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ if err != nil {
+ return nil, err
+ }
+ c.Init()
+
+ // We don't need the receiver.
+ c.CloseRecv()
+ c.Release()
+
+ return c, nil
+}
+
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ hostFile, err := e.file.Connect(flags)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ // Dup the fd so that the new endpoint can manage its lifetime.
+ hostFD, err := syscall.Dup(hostFile.FD())
+ if err != nil {
+ log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
+ return nil, syserr.FromError(err)
+ }
+ // After duplicating, we no longer need hostFile.
+ hostFile.Close()
+
+ c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ if serr != nil {
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ return nil, serr
+ }
+ return c, nil
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *endpoint) Release() {
+ e.dentry.DecRef()
+}
+
+// Passcred implements transport.BoundEndpoint.Passcred.
+func (e *endpoint) Passcred() bool {
+ return false
+}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 2dcb03a73..e1c56d89b 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,6 +8,9 @@ go_library(
"control.go",
"host.go",
"ioctl_unsafe.go",
+ "socket.go",
+ "socket_iovec.go",
+ "socket_unsafe.go",
"tty.go",
"util.go",
"util_unsafe.go",
@@ -16,6 +19,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fdnotifier",
"//pkg/log",
"//pkg/refs",
"//pkg/sentry/arch",
@@ -25,12 +29,18 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
+ "//pkg/sentry/uniqueid",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserr",
"//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/unet",
"//pkg/usermem",
+ "//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1e53b5c1b..05538ea42 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -17,7 +17,6 @@
package host
import (
- "errors"
"fmt"
"math"
"syscall"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
+ unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -156,7 +156,7 @@ type inode struct {
// Note that these flags may become out of date, since they can be modified
// on the host, e.g. with fcntl.
-func fileFlagsFromHostFD(fd int) (int, error) {
+func fileFlagsFromHostFD(fd int) (uint32, error) {
flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0)
if err != nil {
log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err)
@@ -164,7 +164,7 @@ func fileFlagsFromHostFD(fd int) (int, error) {
}
// TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags.
flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
- return flags, nil
+ return uint32(flags), nil
}
// CheckPermissions implements kernfs.Inode.
@@ -361,6 +361,10 @@ func (i *inode) Destroy() {
// Open implements kernfs.Inode.
func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/.
+ if i.Mode().FileType() == linux.S_IFSOCK {
+ return nil, syserror.ENXIO
+ }
return i.open(ctx, vfsd, rp.Mount())
}
@@ -370,42 +374,45 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount) (*vfs.F
return nil, err
}
fileType := s.Mode & linux.FileTypeMask
+ flags, err := fileFlagsFromHostFD(i.hostFD)
+ if err != nil {
+ return nil, err
+ }
+
if fileType == syscall.S_IFSOCK {
if i.isTTY {
- return nil, errors.New("cannot use host socket as TTY")
+ log.Warningf("cannot use host socket fd %d as TTY", i.hostFD)
+ return nil, syserror.ENOTTY
}
- // TODO(gvisor.dev/issue/1672): support importing sockets.
- return nil, errors.New("importing host sockets not supported")
+
+ ep, err := newEndpoint(ctx, i.hostFD)
+ if err != nil {
+ return nil, err
+ }
+ // Currently, we only allow Unix sockets to be imported.
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d)
}
// TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
// we don't allow importing arbitrary file types without proper support.
- var (
- vfsfd *vfs.FileDescription
- fdImpl vfs.FileDescriptionImpl
- )
if i.isTTY {
fd := &ttyFD{
fileDescription: fileDescription{inode: i},
termios: linux.DefaultSlaveTermios,
}
- vfsfd = &fd.vfsfd
- fdImpl = fd
- } else {
- // For simplicity, set offset to 0. Technically, we should
- // only set to 0 on files that are not seekable (sockets, pipes, etc.),
- // and use the offset from the host fd otherwise.
- fd := &fileDescription{inode: i}
- vfsfd = &fd.vfsfd
- fdImpl = fd
- }
-
- flags, err := fileFlagsFromHostFD(i.hostFD)
- if err != nil {
- return nil, err
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
}
- if err := vfsfd.Init(fdImpl, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ // For simplicity, set offset to 0. Technically, we should
+ // only set to 0 on files that are not seekable (sockets, pipes, etc.),
+ // and use the offset from the host fd otherwise.
+ fd := &fileDescription{inode: i}
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return vfsfd, nil
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
new file mode 100644
index 000000000..39dd624a0
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Create a new host-backed endpoint from the given fd.
+func newEndpoint(ctx context.Context, hostFD int) (transport.Endpoint, error) {
+ // Set up an external transport.Endpoint using the host fd.
+ addr := fmt.Sprintf("hostfd:[%d]", hostFD)
+ var q waiter.Queue
+ e, err := NewConnectedEndpoint(ctx, hostFD, &q, addr, true /* saveable */)
+ if err != nil {
+ return nil, err.ToError()
+ }
+ e.Init()
+ ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
+ return ep, nil
+}
+
+// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
+//
+// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
+const maxSendBufferSize = 8 << 20
+
+// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
+// transport.Receiver. It is backed by a host fd that was imported at sentry
+// startup. This fd is shared with a hostfs inode, which retains ownership of
+// it.
+//
+// ConnectedEndpoint is saveable, since we expect that the host will provide
+// the same fd upon restore.
+//
+// As of this writing, we only allow Unix sockets to be imported.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+ // ref keeps track of references to a ConnectedEndpoint.
+ ref refs.AtomicRefCount
+
+ // mu protects fd below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // fd is the host fd backing this endpoint.
+ fd int
+
+ // addr is the address at which this endpoint is bound.
+ addr string
+
+ queue *waiter.Queue
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int64 `state:"nosave"`
+
+ // stype is the type of Unix socket.
+ stype linux.SockType
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+ family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserr.ErrInvalidEndpointState
+ }
+
+ stype, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if err := syscall.SetNonblock(c.fd, true); err != nil {
+ return syserr.FromError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if sndbuf > maxSendBufferSize {
+ log.Warningf("Socket send buffer too large: %d", sndbuf)
+ return syserr.ErrInvalidEndpointState
+ }
+
+ c.stype = linux.SockType(stype)
+ c.sndbuf = int64(sndbuf)
+
+ return nil
+}
+
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd
+// imported at sentry startup,
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewConnectedEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+ e := ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ queue: queue,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.ConnectedEndpoint")
+ return &e, nil
+}
+
+// Init will do the initialization required without holding other locks.
+func (c *ConnectedEndpoint) Init() {
+ if err := fdnotifier.AddFD(int32(c.fd), c.queue); err != nil {
+ panic(err)
+ }
+}
+
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if !controlMessages.Empty() {
+ return 0, false, syserr.ErrInvalidEndpointState
+ }
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == linux.SOCK_STREAM
+
+ n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
+ if n > 0 && err != syserror.EAGAIN {
+ // The caller may need to block to send more data, but
+ // otherwise there isn't anything that can be done about an
+ // error with a partial write.
+ err = nil
+ }
+
+ // There is no need for the callee to call SendNotify because fdWriteVec
+ // uses the host's sendmsg(2) and the host kernel's queue.
+ return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.fd != -1 {
+ fdnotifier.UpdateFD(int32(c.fd))
+ }
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+ if rl > 0 && err != nil {
+ // We got some data, so all we need to do on error is return
+ // the data that we got. Short reads are fine, no need to
+ // block.
+ err = nil
+ }
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ // There is no need for the callee to call RecvNotify because fdReadVec uses
+ // the host's recvmsg(2) and the host kernel's queue.
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+ return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+ return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
+}
+
+func (c *ConnectedEndpoint) destroyLocked() {
+ fdnotifier.RemoveFD(int32(c.fd))
+ c.fd = -1
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release() {
+ c.ref.DecRefWithDestructor(func() {
+ c.mu.Lock()
+ c.destroyLocked()
+ c.mu.Unlock()
+ })
+}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
+
+// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
+// passed through a gofer Unix socket. It is almost the same as
+// ConnectedEndpoint, with the following differences:
+// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
+// the same descriptor number across S/R.
+// - SCMConnectedEndpoint holds ownership of its fd and is responsible for
+// closing it.
+type SCMConnectedEndpoint struct {
+ ConnectedEndpoint
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (e *SCMConnectedEndpoint) Release() {
+ e.ref.DecRefWithDestructor(func() {
+ e.mu.Lock()
+ if err := syscall.Close(e.fd); err != nil {
+ log.Warningf("Failed to close host fd %d: %v", err)
+ }
+ e.destroyLocked()
+ e.mu.Unlock()
+ })
+}
+
+// NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
+// was passed through a Unix socket.
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
+ e := SCMConnectedEndpoint{ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ queue: queue,
+ }}
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.SCMConnectedEndpoint")
+ return &e, nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
new file mode 100644
index 000000000..584c247d2
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// maxIovs is the maximum number of iovecs to pass to the host.
+var maxIovs = linux.UIO_MAXIOV
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += int64(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > maxlen {
+ if truncate {
+ stopLen = maxlen
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > maxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total int64
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := int64(l)
+ if total+stop > stopLen {
+ stop = stopLen - total
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += stop
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
new file mode 100644
index 000000000..35ded24bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
+ flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+ if peek {
+ flags |= syscall.MSG_PEEK
+ }
+
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, false, err
+ }
+
+ var msg syscall.Msghdr
+ if len(control) != 0 {
+ msg.Control = &control[0]
+ msg.Controllen = uint64(len(control))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, 0, 0, false, e
+ }
+ n := int64(rawN)
+
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+ if n > length {
+ return length, n, msg.Controllen, controlTrunc, err
+ }
+
+ return n, n, msg.Controllen, controlTrunc, err
+}
+
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
+
+ var msg syscall.Msghdr
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
+ }
+
+ return int64(n), length, err
+}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 9e8d80414..4a12ae245 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -766,14 +766,17 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
- _, _, err := fs.walkExistingLocked(ctx, rp)
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs()
if err != nil {
return nil, err
}
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
return nil, syserror.ECONNREFUSED
}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index 5ce50625b..271134af8 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -73,26 +73,14 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr
return nil, syserror.ENXIO
}
-// InitSocket initializes a socket FileDescription, with a corresponding
-// Dentry in mnt.
-//
-// fd should be the FileDescription associated with socketImpl, i.e. its first
-// field. mnt should be the global socket mount, Kernel.socketMount.
-func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error {
- fsimpl := mnt.Filesystem().Impl()
- fs := fsimpl.(*kernfs.Filesystem)
-
+// NewDentry constructs and returns a sockfs dentry.
+func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry {
// File mode matches net/socket.c:sock_alloc.
filemode := linux.FileMode(linux.S_IFSOCK | 0600)
i := &inode{}
- i.Init(creds, fs.NextIno(), filemode)
+ i.Init(creds, ino, filemode)
d := &kernfs.Dentry{}
d.Init(i)
-
- opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true}
- if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil {
- return err
- }
- return nil
+ return d.VFSDentry()
}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 5b62f9ebb..36ffcb592 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -697,13 +697,16 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
d, err := resolveLocked(rp)
if err != nil {
return nil, err
}
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
switch impl := d.inode.impl.(type) {
case *socketFile:
return impl.ep, nil
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 7d25e98f7..79766cafe 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -716,7 +716,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
n := t.Arch().NewSignalAct()
n.SerializeFrom(s)
- _, err := t.CopyOut(addr, n)
+ _, err := n.CopyOut(t, addr)
return err
}
@@ -725,7 +725,7 @@ func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
n := t.Arch().NewSignalAct()
var s arch.SignalAct
- if _, err := t.CopyIn(addr, n); err != nil {
+ if _, err := n.CopyIn(t, addr); err != nil {
return s, err
}
n.DeserializeTo(&s)
@@ -737,7 +737,7 @@ func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error {
n := t.Arch().NewSignalStack()
n.SerializeFrom(s)
- _, err := t.CopyOut(addr, n)
+ _, err := n.CopyOut(t, addr)
return err
}
@@ -746,7 +746,7 @@ func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error
func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) {
n := t.Arch().NewSignalStack()
var s arch.SignalStack
- if _, err := t.CopyIn(addr, n); err != nil {
+ if _, err := n.CopyIn(t, addr); err != nil {
return s, err
}
n.DeserializeTo(&s)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index a5035bb7f..8485fb4b6 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -104,6 +104,9 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
cfg.TaskContext.release()
cfg.FSContext.DecRef()
cfg.FDTable.DecRef()
+ if cfg.MountNamespaceVFS2 != nil {
+ cfg.MountNamespaceVFS2.DecRef()
+ }
return nil, err
}
return t, nil
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 716198712..29d457a7e 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -17,8 +17,7 @@
package kvm
import (
- "syscall"
-
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
@@ -37,7 +36,7 @@ type userFpsimdState struct {
}
type userRegs struct {
- Regs syscall.PtraceRegs
+ Regs arch.Registers
sp_el1 uint64
elr_el1 uint64
spsr [KVM_NR_SPSR]uint64
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index c42752d50..6c8f4fa28 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -117,10 +117,10 @@ func TestKernelFloatingPoint(t *testing.T) {
})
}
-func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) {
+func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *arch.Registers, *pagetables.PageTables) bool) {
// Initialize registers & page tables.
var (
- regs syscall.PtraceRegs
+ regs arch.Registers
pt *pagetables.PageTables
)
testutil.SetTestTarget(&regs, target)
@@ -154,7 +154,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func
}
func TestApplicationSyscall(t *testing.T) {
- applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -168,7 +168,7 @@ func TestApplicationSyscall(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -184,7 +184,7 @@ func TestApplicationSyscall(t *testing.T) {
}
func TestApplicationFault(t *testing.T) {
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
@@ -199,7 +199,7 @@ func TestApplicationFault(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
@@ -216,7 +216,7 @@ func TestApplicationFault(t *testing.T) {
}
func TestRegistersSyscall(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
var si arch.SignalInfo
@@ -239,7 +239,7 @@ func TestRegistersSyscall(t *testing.T) {
}
func TestRegistersFault(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
var si arch.SignalInfo
@@ -263,7 +263,7 @@ func TestRegistersFault(t *testing.T) {
}
func TestSegments(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestSegments(regs)
for {
var si arch.SignalInfo
@@ -287,7 +287,7 @@ func TestSegments(t *testing.T) {
}
func TestBounce(t *testing.T) {
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
time.Sleep(time.Millisecond)
c.BounceToKernel()
@@ -302,7 +302,7 @@ func TestBounce(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
time.Sleep(time.Millisecond)
c.BounceToKernel()
@@ -321,7 +321,7 @@ func TestBounce(t *testing.T) {
}
func TestBounceStress(t *testing.T) {
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
randomSleep := func() {
// O(hundreds of microseconds) is appropriate to ensure
// different overlaps and different schedules.
@@ -357,7 +357,7 @@ func TestBounceStress(t *testing.T) {
func TestInvalidate(t *testing.T) {
var data uintptr // Used below.
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, &data) // Read legitimate value.
for {
var si arch.SignalInfo
@@ -398,7 +398,7 @@ func IsFault(err error, si *arch.SignalInfo) bool {
}
func TestEmptyAddressSpace(t *testing.T) {
- applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -412,7 +412,7 @@ func TestEmptyAddressSpace(t *testing.T) {
}
return false
})
- applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -471,7 +471,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
i int // Iteration includes machine.Get() / machine.Put().
a int // Count for ErrContextInterrupt.
)
- applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -493,7 +493,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
func BenchmarkKernelSyscall(b *testing.B) {
// Note that the target passed here is irrelevant, we never execute SwitchToUser.
- applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
// iteration does not include machine.Get() / machine.Put().
for i := 0; i < b.N; i++ {
testutil.Getpid()
@@ -508,7 +508,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
i int
a int
)
- applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD
index f7605df8a..f7feb8683 100644
--- a/pkg/sentry/platform/kvm/testutil/BUILD
+++ b/pkg/sentry/platform/kvm/testutil/BUILD
@@ -13,4 +13,5 @@ go_library(
"testutil_arm64.s",
],
visibility = ["//pkg/sentry/platform/kvm:__pkg__"],
+ deps = ["//pkg/sentry/arch"],
)
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
index 4c108abbf..8048eedec 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
@@ -18,19 +18,20 @@ package testutil
import (
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// TwiddleSegments reads segments into known registers.
func TwiddleSegments()
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
+func SetTestTarget(regs *arch.Registers, fn func()) {
regs.Rip = uint64(reflect.ValueOf(fn).Pointer())
}
// SetTouchTarget sets rax appropriately.
-func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
if target != nil {
regs.Rax = uint64(reflect.ValueOf(target).Pointer())
} else {
@@ -39,12 +40,12 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
}
// RewindSyscall rewinds a syscall RIP.
-func RewindSyscall(regs *syscall.PtraceRegs) {
+func RewindSyscall(regs *arch.Registers) {
regs.Rip -= 2
}
// SetTestRegs initializes registers to known values.
-func SetTestRegs(regs *syscall.PtraceRegs) {
+func SetTestRegs(regs *arch.Registers) {
regs.R15 = 0x15
regs.R14 = 0x14
regs.R13 = 0x13
@@ -64,7 +65,7 @@ func SetTestRegs(regs *syscall.PtraceRegs) {
}
// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
-func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) {
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
if need := ^uint64(0x15); regs.R15 != need {
err = addRegisterMismatch(err, "R15", regs.R15, need)
}
@@ -121,13 +122,13 @@ var fsData uint64 = 0x55
var gsData uint64 = 0x85
// SetTestSegments initializes segments to known values.
-func SetTestSegments(regs *syscall.PtraceRegs) {
+func SetTestSegments(regs *arch.Registers) {
regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer())
regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer())
}
// CheckTestSegments checks that registers were twiddled per TwiddleSegments.
-func CheckTestSegments(regs *syscall.PtraceRegs) (err error) {
+func CheckTestSegments(regs *arch.Registers) (err error) {
if regs.Rax != fsData {
err = addRegisterMismatch(err, "Rax", regs.Rax, fsData)
}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index 40b2e4acc..ca902c8c1 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -19,16 +19,17 @@ package testutil
import (
"fmt"
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
+func SetTestTarget(regs *arch.Registers, fn func()) {
regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
}
// SetTouchTarget sets rax appropriately.
-func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
if target != nil {
regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer())
} else {
@@ -37,19 +38,19 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
}
// RewindSyscall rewinds a syscall RIP.
-func RewindSyscall(regs *syscall.PtraceRegs) {
+func RewindSyscall(regs *arch.Registers) {
regs.Pc -= 4
}
// SetTestRegs initializes registers to known values.
-func SetTestRegs(regs *syscall.PtraceRegs) {
+func SetTestRegs(regs *arch.Registers) {
for i := 0; i <= 30; i++ {
regs.Regs[i] = uint64(i) + 1
}
}
// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
-func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) {
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
for i := 0; i <= 30; i++ {
if need := ^uint64(i + 1); regs.Regs[i] != need {
err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need)
diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
index 24fc5dc62..3b9a870a5 100644
--- a/pkg/sentry/platform/ptrace/ptrace_amd64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -15,9 +15,8 @@
package ptrace
import (
- "syscall"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
@@ -28,12 +27,12 @@ func fpRegSet(useXsave bool) uintptr {
return linux.NT_PRFPREG
}
-func stackPointer(r *syscall.PtraceRegs) uintptr {
+func stackPointer(r *arch.Registers) uintptr {
return uintptr(r.Rsp)
}
// x86 use the fs_base register to store the TLS pointer which can be
-// get/set in "func (t *thread) get/setRegs(regs *syscall.PtraceRegs)".
+// get/set in "func (t *thread) get/setRegs(regs *arch.Registers)".
// So both of the get/setTLS() operations are noop here.
// getTLS gets the thread local storage register.
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go
index 4db28c534..5c869926a 100644
--- a/pkg/sentry/platform/ptrace/ptrace_arm64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go
@@ -15,9 +15,8 @@
package ptrace
import (
- "syscall"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
@@ -25,6 +24,6 @@ func fpRegSet(_ bool) uintptr {
return linux.NT_PRFPREG
}
-func stackPointer(r *syscall.PtraceRegs) uintptr {
+func stackPointer(r *arch.Registers) uintptr {
return uintptr(r.Sp)
}
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 6c0ed7b3e..8b72d24e8 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -24,7 +24,7 @@ import (
)
// getRegs gets the general purpose register set.
-func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
+func (t *thread) getRegs(regs *arch.Registers) error {
iovec := syscall.Iovec{
Base: (*byte)(unsafe.Pointer(regs)),
Len: uint64(unsafe.Sizeof(*regs)),
@@ -43,7 +43,7 @@ func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
}
// setRegs sets the general purpose register set.
-func (t *thread) setRegs(regs *syscall.PtraceRegs) error {
+func (t *thread) setRegs(regs *arch.Registers) error {
iovec := syscall.Iovec{
Base: (*byte)(unsafe.Pointer(regs)),
Len: uint64(unsafe.Sizeof(*regs)),
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 773ddb1ed..2389423b0 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -63,7 +63,7 @@ type thread struct {
// initRegs are the initial registers for the first thread.
//
// These are used for the register set for system calls.
- initRegs syscall.PtraceRegs
+ initRegs arch.Registers
}
// threadPool is a collection of threads.
@@ -317,7 +317,7 @@ const (
)
func (t *thread) dumpAndPanic(message string) {
- var regs syscall.PtraceRegs
+ var regs arch.Registers
message += "\n"
if err := t.getRegs(&regs); err == nil {
message += dumpRegs(&regs)
@@ -423,7 +423,7 @@ func (t *thread) init() {
// This is _not_ for use by application system calls, rather it is for use when
// a system call must be injected into the remote context (e.g. mmap, munmap).
// Note that clones are handled separately.
-func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
+func (t *thread) syscall(regs *arch.Registers) (uintptr, error) {
// Set registers.
if err := t.setRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace set regs failed: %v", err))
@@ -461,7 +461,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
// syscallIgnoreInterrupt ignores interrupts on the system call thread and
// restarts the syscall if the kernel indicates that should happen.
func (t *thread) syscallIgnoreInterrupt(
- initRegs *syscall.PtraceRegs,
+ initRegs *arch.Registers,
sysno uintptr,
args ...arch.SyscallArgument) (uintptr, error) {
for {
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index cd74945e7..84b699f0d 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -41,7 +41,7 @@ const (
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
-func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
regs.Cs = t.initRegs.Cs
regs.Ss = t.initRegs.Ss
regs.Ds = t.initRegs.Ds
@@ -53,7 +53,7 @@ func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
// createSyscallRegs sets up syscall registers.
//
// This should be called to generate registers for a system call.
-func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
// Copy initial registers.
regs := *initRegs
@@ -82,18 +82,18 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch
}
// isSingleStepping determines if the registers indicate single-stepping.
-func isSingleStepping(regs *syscall.PtraceRegs) bool {
+func isSingleStepping(regs *arch.Registers) bool {
return (regs.Eflags & arch.X86TrapFlag) != 0
}
// updateSyscallRegs updates registers after finishing sysemu.
-func updateSyscallRegs(regs *syscall.PtraceRegs) {
+func updateSyscallRegs(regs *arch.Registers) {
// Ptrace puts -ENOSYS in rax on syscall-enter-stop.
regs.Rax = regs.Orig_rax
}
// syscallReturnValue extracts a sensible return from registers.
-func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
rval := int64(regs.Rax)
if rval < 0 {
return 0, syscall.Errno(-rval)
@@ -101,7 +101,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
return uintptr(rval), nil
}
-func dumpRegs(regs *syscall.PtraceRegs) string {
+func dumpRegs(regs *arch.Registers) string {
var m strings.Builder
fmt.Fprintf(&m, "Registers:\n")
@@ -143,7 +143,7 @@ func (t *thread) adjustInitRegsRip() {
}
// Pass the expected PPID to the child via R15 when creating stub process.
-func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
initregs.R15 = uint64(ppid)
// Rbx has to be set to 1 when creating stub process.
initregs.Rbx = 1
@@ -156,7 +156,7 @@ func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 7f5c393f0..bd618fae8 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -41,13 +41,13 @@ const (
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
-func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
}
// createSyscallRegs sets up syscall registers.
//
// This should be called to generate registers for a system call.
-func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
// Copy initial registers (Pc, Sp, etc.).
regs := *initRegs
@@ -78,7 +78,7 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch
}
// isSingleStepping determines if the registers indicate single-stepping.
-func isSingleStepping(regs *syscall.PtraceRegs) bool {
+func isSingleStepping(regs *arch.Registers) bool {
// Refer to the ARM SDM D2.12.3: software step state machine
// return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1).
//
@@ -89,13 +89,13 @@ func isSingleStepping(regs *syscall.PtraceRegs) bool {
}
// updateSyscallRegs updates registers after finishing sysemu.
-func updateSyscallRegs(regs *syscall.PtraceRegs) {
+func updateSyscallRegs(regs *arch.Registers) {
// No special work is necessary.
return
}
// syscallReturnValue extracts a sensible return from registers.
-func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
rval := int64(regs.Regs[0])
if rval < 0 {
return 0, syscall.Errno(-rval)
@@ -103,7 +103,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
return uintptr(rval), nil
}
-func dumpRegs(regs *syscall.PtraceRegs) string {
+func dumpRegs(regs *arch.Registers) string {
var m strings.Builder
fmt.Fprintf(&m, "Registers:\n")
@@ -125,7 +125,7 @@ func (t *thread) adjustInitRegsRip() {
}
// Pass the expected PPID to the child via X7 when creating stub process
-func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
initregs.Regs[7] = uint64(ppid)
// R9 has to be set to 1 when creating stub process.
initregs.Regs[9] = 1
@@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index b69520030..679b287c3 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -79,6 +79,7 @@ go_library(
deps = [
"//pkg/cpuid",
"//pkg/safecopy",
+ "//pkg/sentry/arch",
"//pkg/sentry/platform/ring0/pagetables",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index 86fd5ed58..e6daf24df 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -15,8 +15,7 @@
package ring0
import (
- "syscall"
-
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
)
@@ -72,7 +71,7 @@ type CPU struct {
// registers is a set of registers; these may be used on kernel system
// calls and exceptions via the Registers function.
- registers syscall.PtraceRegs
+ registers arch.Registers
// hooks are kernel hooks.
hooks Hooks
@@ -83,14 +82,14 @@ type CPU struct {
// This is explicitly safe to call during KernelException and KernelSyscall.
//
//go:nosplit
-func (c *CPU) Registers() *syscall.PtraceRegs {
+func (c *CPU) Registers() *arch.Registers {
return &c.registers
}
// SwitchOpts are passed to the Switch function.
type SwitchOpts struct {
// Registers are the user register state.
- Registers *syscall.PtraceRegs
+ Registers *arch.Registers
// FloatingPointState is a byte pointer where floating point state is
// saved and restored.
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index a5ce67885..7fa43c2f5 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -17,7 +17,7 @@
package ring0
import (
- "syscall"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// This is an assembly function.
@@ -41,7 +41,7 @@ func swapgs()
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *syscall.PtraceRegs) Vector
+func sysret(*CPU, *arch.Registers) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +50,7 @@ func sysret(*CPU, *syscall.PtraceRegs) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *syscall.PtraceRegs) Vector
+func iret(*CPU, *arch.Registers) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 4cae10459..549f3d228 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -27,6 +27,7 @@ go_binary(
visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
deps = [
"//pkg/cpuid",
+ "//pkg/sentry/arch",
"//pkg/sentry/platform/ring0/pagetables",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index 85cc3fdad..b8ab120a0 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -20,7 +20,8 @@ import (
"fmt"
"io"
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// Emit prints architecture-specific offsets.
@@ -64,7 +65,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
- p := &syscall.PtraceRegs{}
+ p := &arch.Registers{}
fmt.Fprintf(w, "\n// Ptrace registers.\n")
fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer())
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 057fb5c69..f3de962f0 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -20,7 +20,8 @@ import (
"fmt"
"io"
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// Emit prints architecture-specific offsets.
@@ -87,7 +88,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
- p := &syscall.PtraceRegs{}
+ p := &arch.Registers{}
fmt.Fprintf(w, "\n// Ptrace registers.\n")
fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer())
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 55c0f04f3..ff1cfd8f6 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
tcpHeader = header.TCP(pkt.TransportHeader)
} else {
// The TCP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
- if !ok {
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
- tcpHeader = header.TCP(hdr)
+ tcpHeader = header.TCP(pkt.Data.First())
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 04d03d494..3359418c1 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
udpHeader = header.UDP(pkt.TransportHeader)
} else {
// The UDP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
- udpHeader = header.UDP(hdr)
+ udpHeader = header.UDP(pkt.Data.First())
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index de2cc4bdf..941a91097 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 7c64f30fa..5b29e9d7f 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -103,7 +103,7 @@ func (s *socketOpsCommon) DecRef() {
}
// Release implemements fs.FileOperations.Release.
-func (s *SocketOperations) Release() {
+func (s *socketOpsCommon) Release() {
// Release only decrements a reference on s because s may be referenced in
// the abstract socket namespace.
s.DecRef()
@@ -323,7 +323,10 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Create the socket.
//
- // TODO(gvisor.dev/issue/2324): Correctly set file permissions.
+ // Note that the file permissions here are not set correctly (see
+ // gvisor.dev/issue/2324). There is no convenient way to get permissions
+ // on the socket referred to by s, so we will leave this discrepancy
+ // unresolved until VFS2 replaces this code.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
@@ -373,7 +376,7 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
Path: p,
FollowFinalSymlink: true,
}
- ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop)
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
root.DecRef()
if relPath {
start.DecRef()
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 3e54d49c4..23db93f33 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
@@ -42,30 +43,44 @@ type SocketVFS2 struct {
socketOpsCommon
}
-// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
-func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
- sock := NewFDImpl(ep, stype)
- vfsfd := &sock.vfsfd
- if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+// NewSockfsFile creates a new socket file in the global sockfs mount and
+// returns a corresponding file description.
+func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ fs := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+ d := sockfs.NewDentry(t.Credentials(), fs.NextIno())
+
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d)
+ if err != nil {
return nil, syserr.FromError(err)
}
- return vfsfd, nil
+ return fd, nil
}
-// NewFDImpl creates and returns a new SocketVFS2.
-func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+// NewFileDescription creates and returns a socket file description
+// corresponding to the given mount and dentry.
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) {
// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
// SOCK_DGRAM and don't require CAP_NET_RAW.
if stype == linux.SOCK_RAW {
stype = linux.SOCK_DGRAM
}
- return &SocketVFS2{
+ sock := &SocketVFS2{
socketOpsCommon: socketOpsCommon{
ep: ep,
stype: stype,
},
}
+ vfsfd := &sock.vfsfd
+ if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
}
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
@@ -112,8 +127,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
}
}
- // We expect this to be a FileDescription here.
- ns, err := NewVFS2File(t, ep, s.stype)
+ ns, err := NewSockfsFile(t, ep, s.stype)
if err != nil {
return 0, nil, 0, err
}
@@ -183,11 +197,13 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
Start: start,
Path: path,
}
- err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
- // TODO(gvisor.dev/issue/2324): The file permissions should be taken
- // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind),
- // but VFS1 just always uses 0400. Resolve this inconsistency.
- Mode: linux.S_IFSOCK | 0400,
+ stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // File permissions correspond to net/unix/af_unix.c:unix_bind.
+ Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
Endpoint: bep,
})
if err == syserror.EEXIST {
@@ -259,13 +275,6 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
})
}
-// Release implements vfs.FileDescriptionImpl.
-func (s *SocketVFS2) Release() {
- // Release only decrements a reference on s because s may be referenced in
- // the abstract socket namespace.
- s.DecRef()
-}
-
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
@@ -307,7 +316,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int)
return nil, syserr.ErrInvalidArgument
}
- f, err := NewVFS2File(t, ep, stype)
+ f, err := NewSockfsFile(t, ep, stype)
if err != nil {
ep.Close()
return nil, err
@@ -331,13 +340,13 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
- s1, err := NewVFS2File(t, ep1, stype)
+ s1, err := NewSockfsFile(t, ep1, stype)
if err != nil {
ep1.Close()
ep2.Close()
return nil, nil, err
}
- s2, err := NewVFS2File(t, ep2, stype)
+ s2, err := NewSockfsFile(t, ep2, stype)
if err != nil {
s1.DecRef()
ep2.Close()
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 7e1747a0c..582d37e03 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -355,7 +355,7 @@ func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
func RtSigpending(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
pending := t.PendingSignals()
- _, err := t.CopyOut(addr, pending)
+ _, err := pending.CopyOut(t, addr)
return 0, nil, err
}
@@ -392,7 +392,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if siginfo != 0 {
si.FixSignalCodeForUser()
- if _, err := t.CopyOut(siginfo, si); err != nil {
+ if _, err := si.CopyOut(t, siginfo); err != nil {
return 0, nil, err
}
}
@@ -411,7 +411,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// same way), and that the code is in the allowed set. This same logic
// appears below in RtSigtgqueueinfo and should be kept in sync.
var info arch.SignalInfo
- if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
info.Signo = int32(sig)
@@ -455,7 +455,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// Copy in the info. See RtSigqueueinfo above.
var info arch.SignalInfo
- if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
info.Signo = int32(sig)
@@ -485,7 +485,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// Copy in the signal mask.
var mask linux.SignalSet
- if _, err := t.CopyIn(sigset, &mask); err != nil {
+ if _, err := mask.CopyIn(t, sigset); err != nil {
return 0, nil, err
}
mask &^= kernel.UnblockableSignals
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index a64d86122..981bd8caa 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -237,10 +237,13 @@ func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error) {
+func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error) {
if !rp.Final() {
return nil, syserror.ENOTDIR
}
+ if err := GenericCheckPermissions(rp.Credentials(), MayWrite, anonFileMode, anonFileUID, anonFileGID); err != nil {
+ return nil, err
+ }
return nil, syserror.ECONNREFUSED
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 20e5bb072..1edd584c9 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -494,8 +494,14 @@ type FilesystemImpl interface {
// BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
//
- // - If a non-socket file exists at rp, then BoundEndpointAt returns ECONNREFUSED.
- BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error)
+ // Errors:
+ //
+ // - If the file does not have write permissions, then BoundEndpointAt
+ // returns EACCES.
+ //
+ // - If a non-socket file exists at rp, then BoundEndpointAt returns
+ // ECONNREFUSED.
+ BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error)
// PrependPath prepends a path from vd to vd.Mount().Root() to b.
//
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 022bac127..53d364c5c 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -151,6 +151,24 @@ type SetStatOptions struct {
Stat linux.Statx
}
+// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
+// and FilesystemImpl.BoundEndpointAt().
+type BoundEndpointOptions struct {
+ // Addr is the path of the file whose socket endpoint is being retrieved.
+ // It is generally irrelevant: most endpoints are stored at a dentry that
+ // was created through a bind syscall, so the path can be stored on creation.
+ // However, if the endpoint was created in FilesystemImpl.BoundEndpointAt(),
+ // then we may not know what the original bind address was.
+ //
+ // For example, if connect(2) is called with address "foo" which corresponds
+ // a remote named socket in goferfs, we need to generate an endpoint wrapping
+ // that file. In this case, we can use Addr to set the endpoint address to
+ // "foo". Note that Addr is only a best-effort attempt--we still do not know
+ // the exact address that was used on the remote fs to bind the socket (it
+ // may have been "foo", "./foo", etc.).
+ Addr string
+}
+
// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
// FileDescriptionImpl.Getxattr().
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 9015f2cc1..8d7f8f8af 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -351,33 +351,6 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
}
-// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
-func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (transport.BoundEndpoint, error) {
- if !pop.Path.Begin.Ok() {
- if pop.Path.Absolute {
- return nil, syserror.ECONNREFUSED
- }
- return nil, syserror.ENOENT
- }
- rp := vfs.getResolvingPath(creds, pop)
- for {
- bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp)
- if err == nil {
- vfs.putResolvingPath(rp)
- return bep, nil
- }
- if checkInvariants {
- if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
- }
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return nil, err
- }
- }
-}
-
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
@@ -675,6 +648,33 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
}
+// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
+func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return nil, syserror.ECONNREFUSED
+ }
+ return nil, syserror.ENOENT
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return bep, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
// ListxattrAt returns all extended attribute names for the file at the given
// path.
func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index fcc46420f..101497ed6 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -255,7 +255,7 @@ func (w *Watchdog) runTurn() {
case <-done:
case <-time.After(w.TaskTimeout):
// Report if the watchdog is not making progress.
- // No one is wathching the watchdog watcher though.
+ // No one is watching the watchdog watcher though.
w.reportStuckWatchdog()
<-done
}
@@ -317,10 +317,8 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
- // Dump stack only if a new task is detected or if it sometime has
- // passed since the last time a stack dump was generated.
- showStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod
- w.doAction(w.TaskTimeoutAction, showStack, &buf)
+ // Force stack dump only if a new task is detected.
+ w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
}
func (w *Watchdog) reportStuckWatchdog() {
@@ -329,12 +327,15 @@ func (w *Watchdog) reportStuckWatchdog() {
w.doAction(w.TaskTimeoutAction, false, &buf)
}
-// doAction will take the given action. If the action is LogWarning and
-// showStack is false, then the stack printing will be skipped.
-func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) {
+// doAction will take the given action. If the action is LogWarning, the stack
+// is not always dumpped to the log to prevent log flooding. "forceStack"
+// guarantees that the stack will be dumped regarless.
+func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) {
switch action {
case LogWarning:
- if !showStack {
+ // Dump stack only if forced or sometime has passed since the last time a
+ // stack dump was generated.
+ if !forceStack && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
msg.WriteString("\n...[stack dump skipped]...")
log.Warningf(msg.String())
return
@@ -359,6 +360,7 @@ func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) {
case <-time.After(1 * time.Second):
}
panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
+
default:
panic(fmt.Sprintf("Unknown watchdog action %v", action))
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index f01217c91..8ec5d5d5c 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
-// TrimFront removes the first "count" bytes of the vectorised view. It panics
-// if count > vv.Size().
+// TrimFront removes the first "count" bytes of the vectorised view.
func (vv *VectorisedView) TrimFront(count int) {
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
@@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) {
return
}
count -= len(vv.views[0])
- vv.removeFirst()
+ vv.RemoveFirst()
}
}
@@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) {
count -= len(vv.views[0])
copy(v[copied:], vv.views[0])
copied += len(vv.views[0])
- vv.removeFirst()
+ vv.RemoveFirst()
}
if copied == 0 {
return 0, io.EOF
@@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
count -= len(vv.views[0])
dstVV.AppendView(vv.views[0])
copied += len(vv.views[0])
- vv.removeFirst()
+ vv.RemoveFirst()
}
return copied
}
@@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
-// PullUp returns the first "count" bytes of the vectorised view. If those
-// bytes aren't already contiguous inside the vectorised view, PullUp will
-// reallocate as needed to make them contiguous. PullUp fails and returns false
-// when count > vv.Size().
-func (vv *VectorisedView) PullUp(count int) (View, bool) {
+// First returns the first view of the vectorised view.
+func (vv *VectorisedView) First() View {
if len(vv.views) == 0 {
- return nil, count == 0
- }
- if count <= len(vv.views[0]) {
- return vv.views[0][:count], true
- }
- if count > vv.size {
- return nil, false
+ return nil
}
+ return vv.views[0]
+}
- newFirst := NewView(count)
- i := 0
- for offset := 0; offset < count; i++ {
- copy(newFirst[offset:], vv.views[i])
- if count-offset < len(vv.views[i]) {
- vv.views[i].TrimFront(count - offset)
- break
- }
- offset += len(vv.views[i])
- vv.views[i] = nil
+// RemoveFirst removes the first view of the vectorised view.
+func (vv *VectorisedView) RemoveFirst() {
+ if len(vv.views) == 0 {
+ return
}
- // We're guaranteed that i > 0, since count is too large for the first
- // view.
- vv.views[i-1] = newFirst
- vv.views = vv.views[i-1:]
- return newFirst, true
+ vv.size -= len(vv.views[0])
+ vv.views[0] = nil
+ vv.views = vv.views[1:]
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
@@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader {
}
return readers
}
-
-// removeFirst panics when len(vv.views) < 1.
-func (vv *VectorisedView) removeFirst() {
- vv.size -= len(vv.views[0])
- vv.views[0] = nil
- vv.views = vv.views[1:]
-}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index c56795c7b..106e1994c 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -16,7 +16,6 @@
package buffer
import (
- "bytes"
"reflect"
"testing"
)
@@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) {
})
}
}
-
-var pullUpTestCases = []struct {
- comment string
- in VectorisedView
- count int
- want []byte
- result VectorisedView
- ok bool
-}{
- {
- comment: "simple case",
- in: vv(2, "12"),
- count: 1,
- want: []byte("1"),
- result: vv(2, "12"),
- ok: true,
- },
- {
- comment: "entire View",
- in: vv(2, "1", "2"),
- count: 1,
- want: []byte("1"),
- result: vv(2, "1", "2"),
- ok: true,
- },
- {
- comment: "spanning across two Views",
- in: vv(3, "1", "23"),
- count: 2,
- want: []byte("12"),
- result: vv(3, "12", "3"),
- ok: true,
- },
- {
- comment: "spanning across all Views",
- in: vv(5, "1", "23", "45"),
- count: 5,
- want: []byte("12345"),
- result: vv(5, "12345"),
- ok: true,
- },
- {
- comment: "count = 0",
- in: vv(1, "1"),
- count: 0,
- want: []byte{},
- result: vv(1, "1"),
- ok: true,
- },
- {
- comment: "count = size",
- in: vv(1, "1"),
- count: 1,
- want: []byte("1"),
- result: vv(1, "1"),
- ok: true,
- },
- {
- comment: "count too large",
- in: vv(3, "1", "23"),
- count: 4,
- want: nil,
- result: vv(3, "1", "23"),
- ok: false,
- },
- {
- comment: "empty vv",
- in: vv(0, ""),
- count: 1,
- want: nil,
- result: vv(0, ""),
- ok: false,
- },
- {
- comment: "empty vv, count = 0",
- in: vv(0, ""),
- count: 0,
- want: nil,
- result: vv(0, ""),
- ok: true,
- },
- {
- comment: "empty views",
- in: vv(3, "", "1", "", "23"),
- count: 2,
- want: []byte("12"),
- result: vv(3, "12", "3"),
- ok: true,
- },
-}
-
-func TestPullUp(t *testing.T) {
- for _, c := range pullUpTestCases {
- got, ok := c.in.PullUp(c.count)
-
- // Is the return value right?
- if ok != c.ok {
- t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
- c.comment, c.count, c.in, ok, c.ok)
- }
- if bytes.Compare(got, View(c.want)) != 0 {
- t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
- c.comment, c.count, c.in, got, c.want)
- }
-
- // Is the underlying structure right?
- if !reflect.DeepEqual(c.in, c.result) {
- t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
- c.comment, c.count, c.in, c.result)
- }
- }
-}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 0cac6c0a5..7908c5744 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -71,6 +71,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
+ ICMPv4TTLExceeded = 0
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index ba80b64a8..4f367fe4c 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -17,6 +17,7 @@ package header
import (
"crypto/sha256"
"encoding/binary"
+ "fmt"
"strings"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -445,3 +446,54 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
return GlobalScope, nil
}
}
+
+// InitialTempIID generates the initial temporary IID history value to generate
+// temporary SLAAC addresses with.
+//
+// Panics if initialTempIIDHistory is not at least IIDSize bytes.
+func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) {
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write(seed)
+ var nicIDBuf [4]byte
+ binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID))
+ h.Write(nicIDBuf[:])
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+}
+
+// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an
+// associated stable/permanent SLAAC address.
+//
+// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be
+// used when generating a new temporary IID.
+//
+// Panics if tempIIDHistory is not at least IIDSize bytes.
+func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ addrBytes := []byte(stableAddr)
+ h := sha256.New()
+ h.Write(tempIIDHistory)
+ h.Write(addrBytes[IIDOffsetInIPv6Address:])
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ // The rightmost 64 bits of sum are saved for the next iteration.
+ if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+
+ // The leftmost 64 bits of sum is used as the IID.
+ if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize {
+ panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: IIDOffsetInIPv6Address * 8,
+ }
+}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 073c84ef9..1e2255bfa 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // There should be an ethernet header at the beginning of vv.
- hdr, ok := vv.PullUp(header.EthernetMinimumSize)
- if !ok {
- // Reject the packet if it's shorter than an ethernet header.
+ // Reject the packet if it's shorter than an ethernet header.
+ if vv.Size() < header.EthernetMinimumSize {
return tcpip.ErrBadAddress
}
- linkHeader := header.Ethernet(hdr)
+
+ // There should be an ethernet header at the beginning of vv.
+ linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
vv.TrimFront(len(linkHeader))
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 9cc08d0e2..14b527bc2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -18,10 +18,3 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
-
-go_test(
- name = "rawfile_test",
- size = "small",
- srcs = ["rawfile_test.go"],
- library = ":rawfile",
-)
diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go
deleted file mode 100644
index 8f14ba761..000000000
--- a/pkg/tcpip/link/rawfile/rawfile_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-package rawfile
-
-import (
- "syscall"
- "testing"
-)
-
-func TestNonBlockingWrite3ZeroLength(t *testing.T) {
- fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
- if err != nil {
- t.Fatalf("failed to open /dev/null: %v", err)
- }
- defer syscall.Close(fd)
-
- if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil {
- t.Fatalf("failed to write: %v", err)
- }
-}
-
-func TestNonBlockingWrite3Nil(t *testing.T) {
- fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
- if err != nil {
- t.Fatalf("failed to open /dev/null: %v", err)
- }
- defer syscall.Close(fd)
-
- if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil {
- t.Fatalf("failed to write: %v", err)
- }
-}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 92efd0bf8..44e25d475 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -76,13 +76,9 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
// We have two buffers. Build the iovec that represents them and issue
// a writev syscall.
- var base *byte
- if len(b1) > 0 {
- base = &b1[0]
- }
iovec := [3]syscall.Iovec{
{
- Base: base,
+ Base: &b1[0],
Len: uint64(len(b1)),
},
{
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 33f640b85..27ea3f531 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.ToView())
+ rcvd := []byte(c.packets[0].vv.First())
c.packets = c.packets[:0]
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 0799c8f4d..be2537a82 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ first := pkt.Header.View()
+ if len(first) == 0 {
+ first = pkt.Data.First()
+ }
+ logPacket(prefix, protocol, first, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
@@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() { e.lower.Wait() }
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
size := uint16(0)
var fragmentOffset uint16
var moreFragments bool
-
- // Create a clone of pkt, including any headers if present. Avoid allocating
- // backing memory for the clone.
- views := [8]buffer.View{}
- vv := buffer.NewVectorisedView(0, views[:0])
- vv.AppendView(pkt.Header.View())
- vv.Append(pkt.Data)
-
switch protocol {
case header.IPv4ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- ipv4 := header.IPv4(hdr)
+ ipv4 := header.IPv4(b)
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- vv.TrimFront(int(ipv4.HeaderLength()))
+ b = b[ipv4.HeaderLength():]
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv6MinimumSize)
- if !ok {
- return
- }
- ipv6 := header.IPv6(hdr)
+ ipv6 := header.IPv6(b)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
- vv.TrimFront(header.IPv6MinimumSize)
+ b = b[header.IPv6MinimumSize:]
case header.ARPProtocolNumber:
- hdr, ok := vv.PullUp(header.ARPSize)
- if !ok {
- return
- }
- vv.TrimFront(header.ARPSize)
- arp := header.ARP(hdr)
+ arp := header.ARP(b)
log.Infof(
"%s arp %v (%v) -> %v (%v) valid:%v",
prefix,
@@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
// We aren't guaranteed to have a transport header - it's possible for
// writes via raw endpoints to contain only network headers.
- if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
+ if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize {
log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
return
}
@@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
- if !ok {
- break
- }
- icmp := header.ICMPv4(hdr)
+ icmp := header.ICMPv4(b)
icmpType := "unknown"
if fragmentOffset == 0 {
switch icmp.Type() {
@@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
- if !ok {
- break
- }
- icmp := header.ICMPv6(hdr)
+ icmp := header.ICMPv6(b)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
@@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.UDPProtocolNumber:
transName = "udp"
- hdr, ok := vv.PullUp(header.UDPMinimumSize)
- if !ok {
- break
- }
- udp := header.UDP(hdr)
+ udp := header.UDP(b)
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.TCPProtocolNumber:
transName = "tcp"
- hdr, ok := vv.PullUp(header.TCPMinimumSize)
- if !ok {
- break
- }
- tcp := header.TCP(hdr)
+ tcp := header.TCP(b)
if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index cf73a939e..7acbfa0a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- v, ok := pkt.Data.PullUp(header.ARPSize)
- if !ok {
- return
- }
+ v := pkt.Data.First()
h := header.ARP(v)
if !h.IsValid() {
return
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 4cbefe5ab..c4bf1ba5c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -25,11 +25,7 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
- h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- hdr := header.IPv4(h)
+ h := header.IPv4(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
return
}
- hlen := int(hdr.HeaderLength())
- if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
+ hlen := int(h.HeaderLength())
+ if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
- p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
- v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok {
+ v := pkt.Data.First()
+ if len(v) < header.ICMPv4MinimumSize {
received.Invalid.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 17202cc7a..104aafbed 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
- ip := header.IPv4(h)
+ ip := header.IPv4(pkt.Data.First())
if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
@@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
+ headerView := pkt.Data.First()
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index bdf3a0d25..b68983d10 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -28,11 +28,7 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
- h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
- if !ok {
- return
- }
- hdr := header.IPv6(h)
+ h := header.IPv6(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
@@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
return
}
// Skip the IP header, then handle the fragmentation header if there
// is one.
pkt.Data.TrimFront(header.IPv6MinimumSize)
- p := hdr.TransportProtocol()
+ p := h.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
- if !ok {
- return
- }
- fragHdr := header.IPv6Fragment(f)
- if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
+ f := header.IPv6Fragment(pkt.Data.First())
+ if !f.IsValid() || f.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
return
@@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip fragmentation header and find out the actual protocol
// number.
pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
- p = fragHdr.TransportProtocol()
+ p = f.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
- v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
- if !ok {
+ v := pkt.Data.First()
+ if len(v) < header.ICMPv6MinimumSize {
received.Invalid.Increment()
return
}
@@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// Validate ICMPv6 checksum before processing the packet.
//
+ // Only the first view in vv is accounted for by h. To account for the
+ // rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
- payload.TrimFront(len(h))
+ payload.RemoveFirst()
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
@@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
switch h.Type() {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
- hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
- if !ok {
+ if len(v) < header.ICMPv6PacketTooBigMinimumSize {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := header.ICMPv6(hdr).MTU()
+ mtu := h.MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
- hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
- if !ok {
+ if len(v) < header.ICMPv6DstUnreachableMinimumSize {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch header.ICMPv6(hdr).Code() {
+ switch h.Code() {
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor solicitation, so
- // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
- // NDP messages cannot be fragmented. Also note that in the common case NDP
- // datagrams are very small and ToView() will not incur allocations.
- ns := header.NDPNeighborSolicit(payload.ToView())
+ ns := header.NDPNeighborSolicit(h.NDPPayload())
it, err := ns.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NS option, drop the packet.
@@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
received.Invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
- na := header.NDPNeighborAdvert(payload.ToView())
+ na := header.NDPNeighborAdvert(h.NDPPayload())
it, err := na.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NA option, drop the packet.
@@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
- icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
- if !ok {
+ if len(v) < header.ICMPv6EchoMinimumSize {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- copy(packet, icmpHdr)
+ copy(packet, h)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
@@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoReply:
received.EchoReply.Increment()
- if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
+ if len(v) < header.ICMPv6EchoMinimumSize {
received.Invalid.Increment()
return
}
@@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RouterAdvert:
received.RouterAdvert.Increment()
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ p := h.NDPPayload()
+ if len(p) < header.NDPRAMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
- // The remainder of payload must be only the router advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
- ra := header.NDPRouterAdvert(payload.ToView())
+ ra := header.NDPRouterAdvert(p)
opts := ra.Options()
// Are options valid as per the wire format?
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d412ff688..bd099a7f8 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) {
},
{
typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- },
+ size: header.ICMPv6NeighborSolicitMinimumSize},
{
typ: header.ICMPv6NeighborAdvert,
size: header.ICMPv6NeighborAdvertMinimumSize,
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 486725131..331b0817b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
- if !ok {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
+ headerView := pkt.Data.First()
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index c7c663498..e9c652042 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
// Consume the network header.
- b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
- if !ok {
- return
- }
+ b := pkt.Data.First()
pkt.Data.TrimFront(fwdTestNetHeaderLen)
// Dispatch the packet to the transport protocol.
@@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
+ b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
+ b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
+ b := p.Pkt.Header.View()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Data.ToView()
+ b := p.Pkt.Header.View()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 6b91159d4..6c0a4b24d 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
-//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
@@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa
return drop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Precondition: pk.NetworkHeader is set.
func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
+ // pkt.Data.First().
if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
+ pkt.NetworkHeader = pkt.Data.First()
}
// Check whether the packet matches the IP header filter.
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 8be61f4b1..7b4543caf 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
newPkt := pkt.Clone()
// Set network header.
- headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return RuleDrop, 0
- }
+ headerView := newPkt.Data.First()
netHeader := header.IPv4(headerView)
- newPkt.NetworkHeader = headerView
+ newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
hlen := int(netHeader.HeaderLength())
tlen := int(netHeader.TotalLength())
@@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
udpHeader = header.UDP(newPkt.TransportHeader)
} else {
- if pkt.Data.Size() < header.UDPMinimumSize {
- return RuleDrop, 0
- }
- hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
return RuleDrop, 0
}
- udpHeader = header.UDP(hdr)
+ udpHeader = header.UDP(newPkt.Data.First())
}
udpHeader.SetDestinationPort(rt.MinPort)
case header.TCPProtocolNumber:
@@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
tcpHeader = header.TCP(newPkt.TransportHeader)
} else {
- if pkt.Data.Size() < header.TCPMinimumSize {
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
return RuleDrop, 0
}
- hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize)
- if !ok {
- return RuleDrop, 0
- }
- tcpHeader = header.TCP(hdr)
+ tcpHeader = header.TCP(newPkt.TransportHeader)
}
// TODO(gvisor.dev/issue/170): Need to recompute checksum
// and implement nat connection tracking to support TCP.
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index c11d62f97..fbc0e3f55 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -119,6 +119,32 @@ const (
// identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
// 128 - 64 = 64.
validPrefixLenForAutoGen = 64
+
+ // defaultAutoGenTempGlobalAddresses is the default configuration for whether
+ // or not to generate temporary SLAAC addresses.
+ defaultAutoGenTempGlobalAddresses = true
+
+ // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 7 days (from RFC 4941 section 5).
+ defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour
+
+ // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 1 day (from RFC 4941 section 5).
+ defaultMaxTempAddrPreferredLifetime = 24 * time.Hour
+
+ // defaultRegenAdvanceDuration is the default duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ //
+ // Default = 5s (from RFC 4941 section 5).
+ defaultRegenAdvanceDuration = 5 * time.Second
+
+ // minRegenAdvanceDuration is the minimum duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ minRegenAdvanceDuration = time.Duration(0)
)
var (
@@ -131,6 +157,37 @@ var (
//
// Min = 2hrs.
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+
+ // MaxDesyncFactor is the upper bound for the preferred lifetime's desync
+ // factor for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Must be greater than 0.
+ //
+ // Max = 10m (from RFC 4941 section 5).
+ MaxDesyncFactor = 10 * time.Minute
+
+ // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
+ // maximum preferred lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be preferred for at
+ // least 1hr if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
+
+ // MinMaxTempAddrValidLifetime is the minimum value allowed for the
+ // maximum valid lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be valid for at least
+ // 2hrs if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrValidLifetime = 2 * time.Hour
)
// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
@@ -324,35 +381,49 @@ type NDPConfigurations struct {
// alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
// MAC address), then no attempt will be made to resolve the conflict.
AutoGenAddressConflictRetries uint8
+
+ // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
+ // addresses will be generated for a NIC as part of SLAAC privacy extensions,
+ // RFC 4941.
+ //
+ // Ignored if AutoGenGlobalAddresses is false.
+ AutoGenTempGlobalAddresses bool
+
+ // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary
+ // SLAAC addresses.
+ MaxTempAddrValidLifetime time.Duration
+
+ // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for
+ // temporary SLAAC addresses.
+ MaxTempAddrPreferredLifetime time.Duration
+
+ // RegenAdvanceDuration is the duration before the deprecation of a temporary
+ // address when a new address will be generated.
+ RegenAdvanceDuration time.Duration
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
- MaxRtrSolicitations: defaultMaxRtrSolicitations,
- RtrSolicitationInterval: defaultRtrSolicitationInterval,
- MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
- HandleRAs: defaultHandleRAs,
- DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
- DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
- AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
+ MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime,
+ RegenAdvanceDuration: defaultRegenAdvanceDuration,
}
}
// validate modifies an NDPConfigurations with valid values. If invalid values
// are present in c, the corresponding default values will be used instead.
-//
-// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
-// defaultRetransmitTimer will be used.
-//
-// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then
-// a value of defaultRtrSolicitationInterval will be used.
-//
-// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then
-// a value of defaultMaxRtrSolicitationDelay will be used.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
@@ -365,6 +436,18 @@ func (c *NDPConfigurations) validate() {
if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
}
+
+ if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime {
+ c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime
+ }
+
+ if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime {
+ c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime
+ }
+
+ if c.RegenAdvanceDuration < minRegenAdvanceDuration {
+ c.RegenAdvanceDuration = minRegenAdvanceDuration
+ }
}
// ndpState is the per-interface NDP state.
@@ -394,6 +477,14 @@ type ndpState struct {
// The last learned DHCPv6 configuration from an NDP RA.
dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
+
+ // temporaryIIDHistory is the history value used to generate a new temporary
+ // IID.
+ temporaryIIDHistory [header.IIDSize]byte
+
+ // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for
+ // temporary SLAAC addresses.
+ temporaryAddressDesyncFactor time.Duration
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -414,7 +505,7 @@ type dadState struct {
type defaultRouterState struct {
// Timer to invalidate the default router.
//
- // May not be nil.
+ // Must not be nil.
invalidationTimer *tcpip.CancellableTimer
}
@@ -424,20 +515,48 @@ type defaultRouterState struct {
type onLinkPrefixState struct {
// Timer to invalidate the on-link prefix.
//
- // May not be nil.
+ // Must not be nil.
+ invalidationTimer *tcpip.CancellableTimer
+}
+
+// tempSLAACAddrState holds state associated with a temporary SLAAC address.
+type tempSLAACAddrState struct {
+ // Timer to deprecate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ deprecationTimer *tcpip.CancellableTimer
+
+ // Timer to invalidate the temporary SLAAC address.
+ //
+ // Must not be nil.
invalidationTimer *tcpip.CancellableTimer
+
+ // Timer to regenerate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ regenTimer *tcpip.CancellableTimer
+
+ createdAt time.Time
+
+ // The address's endpoint.
+ //
+ // Must not be nil.
+ ref *referencedNetworkEndpoint
+
+ // Has a new temporary SLAAC address already been regenerated?
+ regenerated bool
}
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
// Timer to deprecate the prefix.
//
- // May not be nil.
+ // Must not be nil.
deprecationTimer *tcpip.CancellableTimer
// Timer to invalidate the prefix.
//
- // May not be nil.
+ // Must not be nil.
invalidationTimer *tcpip.CancellableTimer
// Nonzero only when the address is not valid forever.
@@ -446,19 +565,27 @@ type slaacPrefixState struct {
// Nonzero only when the address is not preferred forever.
preferredUntil time.Time
- // The prefix's permanent address endpoint.
+ // The endpoint for the stable address generated for a SLAAC prefix.
//
// May only be nil when a SLAAC address is being (re-)generated. Otherwise,
- // must not be nil as all SLAAC prefixes must have a SLAAC address.
- ref *referencedNetworkEndpoint
+ // must not be nil as all SLAAC prefixes must have a stable address.
+ stableAddrRef *referencedNetworkEndpoint
+
+ // The temporary (short-lived) addresses generated for the SLAAC prefix.
+ tempAddrs map[tcpip.Address]tempSLAACAddrState
- // The number of times a permanent address has been generated for the prefix.
+ // The next two fields are used by both stable and temporary addresses
+ // generated for a SLAAC prefix. This is safe as only 1 address will be
+ // in the generation and DAD process at any time. That is, no two addresses
+ // will be generated at the same time for a given SLAAC prefix.
+
+ // The number of times an address has been generated.
//
// Addresses may be regenerated in reseponse to a DAD conflicts.
generationAttempts uint8
- // The maximum number of times to attempt regeneration of a permanent SLAAC
- // address in response to DAD conflicts.
+ // The maximum number of times to attempt regeneration of a SLAAC address
+ // in response to DAD conflicts.
maxGenerationAttempts uint8
}
@@ -536,10 +663,10 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
if done {
// If we reach this point, it means that DAD was stopped after we released
// the NIC's read lock and before we obtained the write lock.
- ndp.nic.mu.Unlock()
return
}
@@ -551,8 +678,6 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// schedule the next DAD timer.
remaining--
timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
-
- ndp.nic.mu.Unlock()
return
}
@@ -560,15 +685,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// the last NDP NS. Either way, clean up addr's DAD state and let the
// integrator know DAD has completed.
delete(ndp.dad, addr)
- ndp.nic.mu.Unlock()
-
- if err != nil {
- log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
- }
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
}
+
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && ref.configType == slaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
})
ndp.dad[addr] = dadState{
@@ -953,9 +1081,10 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
// Check if we already maintain SLAAC state for prefix.
- if _, ok := ndp.slaacPrefixes[prefix]; ok {
+ if state, ok := ndp.slaacPrefixes[prefix]; ok {
// As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
- ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl)
+ ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl)
+ ndp.slaacPrefixes[prefix] = state
return
}
@@ -996,7 +1125,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
}
- ndp.deprecateSLAACAddress(state.ref)
+ ndp.deprecateSLAACAddress(state.stableAddrRef)
}),
invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
@@ -1006,6 +1135,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.invalidateSLAACPrefix(prefix, state)
}),
+ tempAddrs: make(map[tcpip.Address]tempSLAACAddrState),
maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
}
@@ -1035,9 +1165,49 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
state.validUntil = now.Add(vl)
}
+ // If the address is assigned (DAD resolved), generate a temporary address.
+ if state.stableAddrRef.getKind() == permanent {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
+ }
+
ndp.slaacPrefixes[prefix] = state
}
+// addSLAACAddr adds a SLAAC address to the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+ // If the nic already has this address, do nothing further.
+ if ndp.nic.hasPermanentAddrLocked(addr.Address) {
+ return nil
+ }
+
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return nil
+ }
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ // Informed by the integrator not to add the address.
+ return nil
+ }
+
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+
+ ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ if err != nil {
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ }
+
+ return ref
+}
+
// generateSLAACAddr generates a SLAAC address for prefix.
//
// Returns true if an address was successfully generated.
@@ -1046,7 +1216,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
- if r := state.ref; r != nil {
+ if r := state.stableAddrRef; r != nil {
panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
}
@@ -1085,39 +1255,18 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
return false
}
- generatedAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(addrBytes),
- PrefixLen: validPrefixLenForAutoGen,
- },
- }
-
- // If the nic already has this address, do nothing further.
- if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) {
- return false
- }
-
- // Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.nic.stack.ndpDisp
- if ndpDisp == nil {
- return false
+ generatedAddr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) {
- // Informed by the integrator not to add the address.
- return false
+ if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
+ state.stableAddrRef = ref
+ state.generationAttempts++
+ return true
}
- deprecated := time.Since(state.preferredUntil) >= 0
- ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
- if err != nil {
- panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err))
- }
-
- state.generationAttempts++
- state.ref = ref
- return true
+ return false
}
// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
@@ -1143,24 +1292,167 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
ndp.invalidateSLAACPrefix(prefix, state)
}
-// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+// generateTempSLAACAddr generates a new temporary SLAAC address.
//
-// pl is the new preferred lifetime. vl is the new valid lifetime.
+// If resetGenAttempts is true, the prefix's generation counter will be reset.
+//
+// Returns true if a new address was generated.
+func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
+ // Are we configured to auto-generate new temporary global addresses for the
+ // prefix?
+ if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() {
+ return false
+ }
+
+ if resetGenAttempts {
+ prefixState.generationAttempts = 0
+ prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if prefixState.generationAttempts == prefixState.maxGenerationAttempts {
+ return false
+ }
+
+ stableAddr := prefixState.stableAddrRef.ep.ID().LocalAddress
+ now := time.Now()
+
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime.
+ vl := ndp.configs.MaxTempAddrValidLifetime
+ if prefixState.validUntil != (time.Time{}) {
+ if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
+ vl = prefixVL
+ }
+ }
+
+ if vl <= 0 {
+ // Cannot create an address without a valid lifetime.
+ return false
+ }
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or the
+ // maximum temporary address preferred lifetime - the temporary address desync
+ // factor.
+ pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
+ if prefixState.preferredUntil != (time.Time{}) {
+ if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
+ // Respect the preferred lifetime of the prefix, as per RFC 4941 section
+ // 3.3 step 4.
+ pl = prefixPL
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, a temporary address is created only if
+ // the calculated preferred lifetime is greater than the advance regeneration
+ // duration. In particular, we MUST NOT create a temporary address with a zero
+ // Preferred Lifetime.
+ if pl <= ndp.configs.RegenAdvanceDuration {
+ return false
+ }
+
+ generatedAddr := header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
+
+ // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
+ // address with a zero preferred lifetime. The checks above ensure this
+ // so we know the address is not deprecated.
+ ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
+ if ref == nil {
+ return false
+ }
+
+ state := tempSLAACAddrState{
+ deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
+ }
+
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ }),
+ invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr))
+ }
+
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
+ }),
+ regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr))
+ }
+
+ // If an address has already been regenerated for this address, don't
+ // regenerate another address.
+ if tempAddrState.regenerated {
+ return
+ }
+
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */)
+ prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
+ ndp.slaacPrefixes[prefix] = prefixState
+ }),
+ createdAt: now,
+ ref: ref,
+ }
+
+ state.deprecationTimer.Reset(pl)
+ state.invalidationTimer.Reset(vl)
+ state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+
+ prefixState.generationAttempts++
+ prefixState.tempAddrs[generatedAddr.Address] = state
+
+ return true
+}
+
+// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) {
- prefixState, ok := ndp.slaacPrefixes[prefix]
+func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
if !ok {
- panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix))
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix))
}
- defer func() { ndp.slaacPrefixes[prefix] = prefixState }()
+ ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts)
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
// If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
if deprecated {
- ndp.deprecateSLAACAddress(prefixState.ref)
+ ndp.deprecateSLAACAddress(prefixState.stableAddrRef)
} else {
- prefixState.ref.deprecated = false
+ prefixState.stableAddrRef.deprecated = false
}
// If prefix was preferred for some finite lifetime before, stop the
@@ -1190,36 +1482,118 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim
//
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
- // Handle the infinite valid lifetime separately as we do not keep a timer in
- // this case.
if vl >= header.NDPInfiniteLifetime {
+ // Handle the infinite valid lifetime separately as we do not keep a timer
+ // in this case.
prefixState.invalidationTimer.StopLocked()
prefixState.validUntil = time.Time{}
- return
- }
+ } else {
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the prefix was originally set to be valid forever, assume the
+ // remaining time to be the maximum possible value.
+ if prefixState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(prefixState.validUntil)
+ }
- var effectiveVl time.Duration
- var rl time.Duration
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl > MinPrefixInformationValidLifetimeForUpdate {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
- // If the prefix was originally set to be valid forever, assume the remaining
- // time to be the maximum possible value.
- if prefixState.validUntil == (time.Time{}) {
- rl = header.NDPInfiniteLifetime
- } else {
- rl = time.Until(prefixState.validUntil)
+ if effectiveVl != 0 {
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
+ }
}
- if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
- effectiveVl = vl
- } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ // If DAD is not yet complete on the stable address, there is no need to do
+ // work with temporary addresses.
+ if prefixState.stableAddrRef.getKind() != permanent {
return
- } else {
- effectiveVl = MinPrefixInformationValidLifetimeForUpdate
}
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
- prefixState.validUntil = now.Add(effectiveVl)
+ // Note, we do not need to update the entries in the temporary address map
+ // after updating the timers because the timers are held as pointers.
+ var regenForAddr tcpip.Address
+ allAddressesRegenerated := true
+ for tempAddr, tempAddrState := range prefixState.tempAddrs {
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime. Note, the valid lifetime of a
+ // temporary address is relative to the address's creation time.
+ validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
+ if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ validUntil = prefixState.validUntil
+ }
+
+ // If the address is no longer valid, invalidate it immediately. Otherwise,
+ // reset the invalidation timer.
+ newValidLifetime := validUntil.Sub(now)
+ if newValidLifetime <= 0 {
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
+ continue
+ }
+ tempAddrState.invalidationTimer.StopLocked()
+ tempAddrState.invalidationTimer.Reset(newValidLifetime)
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or
+ // the maximum temporary address preferred lifetime - the temporary address
+ // desync factor. Note, the preferred lifetime of a temporary address is
+ // relative to the address's creation time.
+ preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
+ if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ preferredUntil = prefixState.preferredUntil
+ }
+
+ // If the address is no longer preferred, deprecate it immediately.
+ // Otherwise, reset the deprecation timer.
+ newPreferredLifetime := preferredUntil.Sub(now)
+ tempAddrState.deprecationTimer.StopLocked()
+ if newPreferredLifetime <= 0 {
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ } else {
+ tempAddrState.ref.deprecated = false
+ tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ }
+
+ tempAddrState.regenTimer.StopLocked()
+ if tempAddrState.regenerated {
+ } else {
+ allAddressesRegenerated = false
+
+ if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration {
+ // The new preferred lifetime is less than the advance regeneration
+ // duration so regenerate an address for this temporary address
+ // immediately after we finish iterating over the temporary addresses.
+ regenForAddr = tempAddr
+ } else {
+ tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ }
+ }
+ }
+
+ // Generate a new temporary address if all of the existing temporary addresses
+ // have been regenerated, or we need to immediately regenerate an address
+ // due to an update in preferred lifetime.
+ //
+ // If each temporay address has already been regenerated, no new temporary
+ // address will be generated. To ensure continuation of temporary SLAAC
+ // addresses, we manually try to regenerate an address here.
+ if len(regenForAddr) != 0 || allAddressesRegenerated {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok {
+ state.regenerated = true
+ prefixState.tempAddrs[regenForAddr] = state
+ }
+ }
}
// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
@@ -1243,11 +1617,11 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
- if r := state.ref; r != nil {
+ if r := state.stableAddrRef; r != nil {
// Since we are already invalidating the prefix, do not invalidate the
// prefix when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil {
- panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err))
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
}
}
@@ -1265,14 +1639,14 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
prefix := addr.Subnet()
state, ok := ndp.slaacPrefixes[prefix]
- if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress {
+ if !ok || state.stableAddrRef == nil || addr.Address != state.stableAddrRef.ep.ID().LocalAddress {
return
}
if !invalidatePrefix {
// If the prefix is not being invalidated, disassociate the address from the
// prefix and do nothing further.
- state.ref = nil
+ state.stableAddrRef = nil
ndp.slaacPrefixes[prefix] = state
return
}
@@ -1286,11 +1660,68 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ // Invalidate all temporary addresses.
+ for tempAddr, tempAddrState := range state.tempAddrs {
+ ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
+ }
+
+ state.stableAddrRef = nil
state.deprecationTimer.StopLocked()
state.invalidationTimer.StopLocked()
delete(ndp.slaacPrefixes, prefix)
}
+// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ // Since we are already invalidating the address, do not invalidate the
+ // address when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
+// SLAAC address's resources from ndp.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
+
+ if !invalidateAddr {
+ return
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr))
+ }
+
+ tempAddrState, ok := state.tempAddrs[addr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
+// timers and entry.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.invalidationTimer.StopLocked()
+ tempAddrState.regenTimer.StopLocked()
+ delete(tempAddrs, tempAddr)
+}
+
// cleanupState cleans up ndp's state.
//
// If hostOnly is true, then only host-specific state will be cleaned up.
@@ -1450,3 +1881,13 @@ func (ndp *ndpState) stopSolicitingRouters() {
ndp.rtrSolicitTimer.Stop()
ndp.rtrSolicitTimer = nil
}
+
+// initializeTempAddrState initializes state related to temporary SLAAC
+// addresses.
+func (ndp *ndpState) initializeTempAddrState() {
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+
+ if MaxDesyncFactor != 0 {
+ ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6dd460984..421df674f 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1801,6 +1801,726 @@ func TestAutoGenAddr(t *testing.T) {
}
}
+func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
+ ret := ""
+ for _, c := range containList {
+ if !containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should have %s in the list of addresses\n", c)
+ }
+ }
+ for _, c := range notContainList {
+ if containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should not have %s in the list of addresses\n", c)
+ }
+ }
+ return ret
+}
+
+// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
+// configured to do so as part of IPv6 Privacy Extensions.
+func TestAutoGenTempAddr(t *testing.T) {
+ const (
+ nicID = 1
+ newMinVL = 5
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for i, test := range tests {
+ i := i
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
+ // lifetimes.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not
+// generated for auto generated link-local addresses.
+func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
+ const nicID = 1
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // The stable link-local address should auto-generate and resolve DAD.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
+// will not be generated until after DAD completes, even if a new Router
+// Advertisement is received to refresh lifetimes.
+func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
+ const (
+ nicID = 1
+ dadTransmits = 1
+ retransmitTimer = 2 * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = 0
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Receive an RA to trigger SLAAC for prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // DAD on the stable address for prefix has not yet completed. Receiving a new
+ // RA that would refresh lifetimes should not generate a temporary SLAAC
+ // address for the prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ default:
+ }
+
+ // Wait for DAD to complete for the stable address then expect the temporary
+ // address to be generated.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+}
+
+// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are
+// regenerated.
+func TestAutoGenTempAddrRegen(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Stop generating temporary addresses
+ ndpConfigs.AutoGenTempGlobalAddresses = false
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+
+ // Wait for all the temporary addresses to get invalidated.
+ tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
+ invalidateAfter := newMinVLDuration - 2*regenAfter
+ for _, addr := range tempAddrs {
+ // Wait for a deprecation then invalidation event, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation timers could fire in any
+ // order.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we shouldn't get a deprecation
+ // event after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
+ case <-time.After(defaultTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event = %+v", e)
+ }
+ case <-time.After(invalidateAfter + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+
+ invalidateAfter = regenAfter
+ }
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+}
+
+// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
+// regeneration timer gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate the prefix.
+ //
+ // A new temporary address should be generated after the regeneration
+ // time has passed since the prefix is deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ }
+
+ // Prefer the prefix again.
+ //
+ // A new temporary address should immediately be generated since the
+ // regeneration time has already passed since the last address was generated
+ // - this regeneration does not depend on a timer.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr2, newAddr)
+
+ // Increase the maximum lifetimes for temporary addresses to large values
+ // then refresh the lifetimes of the prefix.
+ //
+ // A new address should not be generated after the regeneration time that was
+ // expected for the previous check. This is because the preferred lifetime for
+ // the temporary addresses has increased, so it will take more time to
+ // regenerate a new temporary address. Note, new addresses are only
+ // regenerated after the preferred lifetime - the regenerate advance duration
+ // as paased.
+ ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
+ ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ }
+
+ // Set the maximum lifetimes for temporary addresses such that on the next
+ // RA, the regeneration timer gets reset.
+ //
+ // The maximum lifetime is the sum of the minimum lifetimes for temporary
+ // addresses + the time that has already passed since the last address was
+ // generated so that the regeneration timer is needed to generate the next
+ // address.
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout
+ ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
+ ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+}
+
// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
// channel.Endpoint and stack.Stack.
//
@@ -2196,7 +2916,6 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
} else {
t.Fatalf("got unexpected auto-generated event")
}
-
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
@@ -2808,9 +3527,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
}
-// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an
-// auto-generated IPv6 address in response to a DAD conflict.
-func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) {
+func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const nicID = 1
const nicName = "nic"
const dadTransmits = 1
@@ -2818,6 +3535,13 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) {
const maxMaxRetries = 3
const lifetimeSeconds = 10
+ // Needed for the temporary address sub test.
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
n, err := rand.Read(secretKey)
@@ -2830,185 +3554,234 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) {
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
- for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
- for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
- addrTypes := []struct {
- name string
- ndpConfigs stack.NDPConfigurations
- autoGenLinkLocal bool
- subnet tcpip.Subnet
- triggerSLAACFn func(e *channel.Endpoint)
- }{
- {
- name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- subnet: subnet,
- triggerSLAACFn: func(e *channel.Endpoint) {
- // Receive an RA with prefix1 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix {
+ addrBytes := []byte(subnet.ID())
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)),
+ PrefixLen: 64,
+ }
+ }
- },
- },
- {
- name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- AutoGenAddressConflictRetries: maxRetries,
- },
- autoGenLinkLocal: true,
- subnet: header.IPv6LinkLocalPrefix.Subnet(),
- triggerSLAACFn: func(e *channel.Endpoint) {},
- },
+ expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- for _, addrType := range addrTypes {
- maxRetries := maxRetries
- numFailures := numFailures
- addrType := addrType
+ expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) {
- t.Parallel()
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: secretKey,
- },
- })
- opts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
- }
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
+ expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
- addrType.triggerSLAACFn(e)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
- // Simulate DAD conflicts so the address is regenerated.
- for i := uint8(0); i < numFailures; i++ {
- addrBytes := []byte(addrType.subnet.ID())
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)),
- PrefixLen: 64,
- }
- expectAutoGenAddrEvent(addr, newAddr)
+ stableAddrForTempAddrTest := addrForSubnet(subnet, 0)
- // Should not have any addresses assigned to the NIC.
- mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ // Receive an RA with prefix1 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ return nil
+
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(subnet, dadCounter)
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ autoGenLinkLocal: true,
+ prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ return nil
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter)
+ },
+ },
+ {
+ name: "Temporary address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ header.InitialTempIID(tempIIDHistory, nil, nicID)
+
+ // Generate a stable SLAAC address so temporary addresses will be
+ // generated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
+ expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true)
+
+ // The stable address will be assigned throughout the test.
+ return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
+ },
+ addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address)
+ },
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the parallel
+ // tests complete and limit the number of parallel tests running at the same
+ // time to reduce flakes.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run(addrType.name, func(t *testing.T) {
+ for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
+ for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
+ maxRetries := maxRetries
+ numFailures := numFailures
+ addrType := addrType
+
+ t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
- if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := addrType.ndpConfigs
+ ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- // Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
- expectAutoGenAddrEvent(addr, invalidatedAddr)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ var tempIIDHistory [header.IIDSize]byte
+ stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+
+ // Simulate DAD conflicts so the address is regenerated.
+ for i := uint8(0); i < numFailures; i++ {
+ addr := addrType.addrGenFn(i, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
}
- default:
- t.Fatal("expected DAD event")
- }
- // Attempting to add the address manually should not fail if the
- // address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
- }
- if err := s.RemoveAddress(nicID, addr.Address); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
- }
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
}
- default:
- t.Fatal("expected DAD event")
- }
- }
+ expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
- // Should not have any addresses assigned to the NIC.
- mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
- }
+ // Attempting to add the address manually should not fail if the
+ // address's state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if err := s.RemoveAddress(nicID, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
+ }
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
+ }
- // If we had less failures than generation attempts, we should have an
- // address after DAD resolves.
- if maxRetries+1 > numFailures {
- addrBytes := []byte(addrType.subnet.ID())
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)),
- PrefixLen: 64,
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
}
- expectAutoGenAddrEvent(addr, newAddr)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ // If we had less failures than generation attempts, we should have
+ // an address after DAD resolves.
+ if maxRetries+1 > numFailures {
+ addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, &ndpDisp, addr.Address, true)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
+ t.Fatal(mismatch)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for DAD event")
}
- mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if mainAddr != addr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr)
+ // Should not attempt address generation again.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
}
- }
-
- // Should not attempt address regeneration again.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
- }
- })
+ })
+ }
}
- }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0c2b1f36a..25188b4fb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -131,6 +131,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
+ nic.mu.ndp.initializeTempAddrState()
// Register supported packet endpoint protocols.
for _, netProto := range header.Ethertypes {
@@ -1014,14 +1015,14 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
switch r.protocol {
case header.IPv6ProtocolNumber:
- return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */)
+ return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
default:
r.expireLocked()
return nil
}
}
-func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error {
+func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
addr := r.addrWithPrefix()
isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
@@ -1031,8 +1032,11 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
- if r.configType == slaac {
- n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation)
+ switch r.configType {
+ case slaac:
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case slaacTemp:
+ n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
}
}
@@ -1203,12 +1207,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.PacketsReceived.Increment()
}
- netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
- if !ok {
+ if len(pkt.Data.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(netHeader)
+
+ src, dst := netProto.ParseAddresses(pkt.Data.First())
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1289,8 +1293,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
- pkt.Header = buffer.NewPrependable(linkHeaderLen)
+
+ firstData := pkt.Data.First()
+ pkt.Data.RemoveFirst()
+
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
+ pkt.Header = buffer.NewPrependableFromView(firstData)
+ } else {
+ firstDataLen := len(firstData)
+
+ // pkt.Header should have enough capacity to hold n.linkEP's headers.
+ pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
+
+ // TODO(b/151227689): avoid copying the packet when forwarding
+ if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
+ }
}
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
@@ -1318,13 +1336,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ if len(pkt.Data.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@@ -1362,12 +1379,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- transHeader, ok := pkt.Data.PullUp(8)
- if !ok {
+ if len(pkt.Data.First()) < 8 {
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
return
}
@@ -1436,12 +1452,19 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
// new address will be generated for it.
- if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil {
+ if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
return err
}
- if ref.configType == slaac {
- n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet())
+ prefix := ref.addrWithPrefix().Subnet()
+
+ switch ref.configType {
+ case slaac:
+ n.mu.ndp.regenerateSLAACAddr(prefix)
+ case slaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
}
return nil
@@ -1540,9 +1563,14 @@ const (
// multicast group).
static networkEndpointConfigType = iota
- // A slaac configured endpoint is an IPv6 endpoint that was
- // added by SLAAC as per RFC 4862 section 5.5.3.
+ // A SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4862 section 5.5.3.
slaac
+
+ // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
+ // not expected to be valid (or preferred) forever; hence the term temporary.
+ slaacTemp
)
type referencedNetworkEndpoint struct {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7d36f8e84..dc125f25e 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -37,13 +37,7 @@ type PacketBuffer struct {
Data buffer.VectorisedView
// Header holds the headers of outbound packets. As a packet is passed
- // down the stack, each layer adds to Header. Note that forwarded
- // packets don't populate Headers on their way out -- their headers and
- // payload are never parsed out and remain in Data.
- //
- // TODO(gvisor.dev/issue/170): Forwarded packets don't currently
- // populate Header, but should. This will be doable once early parsing
- // (https://github.com/google/gvisor/pull/1995) is supported.
+ // down the stack, each layer adds to Header.
Header buffer.Prependable
// These fields are used by both inbound and outbound packets. They
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 41398a1b6..4a2dc3dc6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -464,6 +464,10 @@ type Stack struct {
// (IIDs) as outlined by RFC 7217.
opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
// forwarder holds the packets that wait for their link-address resolutions
// to complete, and forwards them when each resolution is done.
forwarder *forwardQueue
@@ -541,6 +545,21 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -664,6 +683,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index d45d2cc1f..c7634ceb1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
- b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
+ b := pkt.Data.First()
pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
+ nb := pkt.Data.First()
+ if len(nb) < fakeNetHeaderLen {
return
}
+
pkt.Data.TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
return
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a611e44ab..3084e6593 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- hdrs := p.Pkt.Data.ToView()
- if dst := hdrs[0]; dst != 3 {
+ if dst := p.Pkt.Header.View()[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := hdrs[1]; src != 1 {
+ if src := p.Pkt.Header.View()[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b1d820372..feef8dca0 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.Data.First())
+ if h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.Data.First())
+ if h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 7712ce652..40461fd31 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h, ok := s.data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
- hdr := header.TCP(h)
+ h := header.TCP(s.data.First())
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
@@ -160,16 +156,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(hdr.DataOffset())
- if offset < header.TCPMinimumSize {
- return false
- }
- hdrWithOpts, ok := s.data.PullUp(offset)
- if !ok {
+ offset := int(h.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(h) {
return false
}
- s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
+ s.options = []byte(h[header.TCPMinimumSize:offset])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -181,19 +173,18 @@ func (s *segment) parse() bool {
s.data.TrimFront(offset)
}
if verifyChecksum {
- hdr = header.TCP(hdrWithOpts)
- s.csum = hdr.Checksum()
+ s.csum = h.Checksum()
xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = hdr.CalculateChecksum(xsum)
+ xsum = h.CalculateChecksum(xsum)
s.data.TrimFront(offset)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
- s.ackNumber = seqnum.Value(hdr.AckNumber())
- s.flags = hdr.Flags()
- s.window = seqnum.Size(hdr.WindowSize())
+ s.sequenceNumber = seqnum.Value(h.SequenceNumber())
+ s.ackNumber = seqnum.Value(h.AckNumber())
+ s.flags = h.Flags()
+ s.window = seqnum.Size(h.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 1dd63dd61..ace79b7b2 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
// createConnectedWithSACKPermittedOption creates and connects c.ep with the
@@ -439,21 +440,28 @@ func TestSACKRecovery(t *testing.T) {
// Receive the retransmitted packet.
c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ }
}
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
@@ -517,22 +525,28 @@ func TestSACKRecovery(t *testing.T) {
bytesRead += maxPayload
}
- // In SACK recovery only the first segment is fast retransmitted when
- // entering recovery.
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
+ metricPollFn = func() error {
+ // In SACK recovery only the first segment is fast retransmitted when
+ // entering recovery.
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
- }
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 286c66cf5..7e574859b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -209,8 +210,15 @@ func TestTCPResetsSentIncrement(t *testing.T) {
c.SendPacket(nil, ackHeaders)
c.GetPacket()
- if got := stats.TCP.ResetsSent.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+
+ metricPollFn := func() error {
+ if got := stats.TCP.ResetsSent.Value(); got != want {
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
}
@@ -3548,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
+ tcpbuf := vv.First()[header.IPv4MinimumSize:]
tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
@@ -3575,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
+ tcpbuf := vv.First()[header.IPv4MinimumSize:]
// Overwrite a byte in the payload which should cause checksum
// verification to fail.
tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 756ab913a..edb54f0be 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
},
}
packet.data = pkt.Data
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 52af6de22..6e31a9bac 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Malformed packet.
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- if int(header.UDP(h).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true