summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/fcntl.go41
-rw-r--r--pkg/abi/linux/file.go10
-rw-r--r--pkg/sentry/arch/arch_amd64.go4
-rw-r--r--pkg/sentry/control/pprof.go15
-rw-r--r--pkg/sentry/control/proc.go48
-rw-r--r--pkg/sentry/control/proc_test.go10
-rw-r--r--pkg/sentry/fs/gofer/session.go6
-rw-r--r--pkg/sentry/fs/proc/task.go34
-rw-r--r--pkg/sentry/fs/tty/terminal.go4
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go29
-rw-r--r--pkg/sentry/fsimpl/memfs/benchmark_test.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/pipe_test.go6
-rw-r--r--pkg/sentry/fsimpl/proc/filesystems.go2
-rw-r--r--pkg/sentry/fsimpl/proc/mounts.go2
-rw-r--r--pkg/sentry/kernel/kernel.go20
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go6
-rw-r--r--pkg/sentry/kernel/syscalls.go8
-rw-r--r--pkg/sentry/kernel/task.go20
-rw-r--r--pkg/sentry/kernel/task_block.go8
-rw-r--r--pkg/sentry/kernel/task_clone.go1
-rw-r--r--pkg/sentry/kernel/task_exec.go3
-rw-r--r--pkg/sentry/kernel/task_exit.go1
-rw-r--r--pkg/sentry/kernel/task_log.go86
-rw-r--r--pkg/sentry/kernel/task_run.go14
-rw-r--r--pkg/sentry/kernel/task_start.go8
-rw-r--r--pkg/sentry/kernel/task_syscall.go8
-rw-r--r--pkg/sentry/kernel/tty.go11
-rw-r--r--pkg/sentry/loader/elf.go2
-rw-r--r--pkg/sentry/sighandling/sighandling.go75
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go26
-rw-r--r--pkg/sentry/socket/control/control.go63
-rw-r--r--pkg/sentry/socket/hostinet/socket.go26
-rw-r--r--pkg/sentry/socket/netstack/netstack.go2
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto1
-rw-r--r--pkg/sentry/socket/socket.go5
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/linux64.go32
-rw-r--r--pkg/sentry/strace/select.go53
-rw-r--r--pkg/sentry/strace/socket.go9
-rw-r--r--pkg/sentry/strace/strace.go2
-rw-r--r--pkg/sentry/strace/syscalls.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go24
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go21
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go70
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go71
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go26
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/file_description.go93
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go10
-rw-r--r--pkg/sentry/vfs/filesystem.go22
-rw-r--r--pkg/sentry/vfs/mount.go69
-rw-r--r--pkg/sentry/vfs/options.go12
-rw-r--r--pkg/sentry/vfs/permissions.go62
-rw-r--r--pkg/sentry/vfs/syscalls.go237
-rw-r--r--pkg/sentry/vfs/vfs.go378
-rw-r--r--pkg/sentry/watchdog/watchdog.go4
-rw-r--r--pkg/syncutil/BUILD2
-rw-r--r--pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go21
-rw-r--r--pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go16
-rw-r--r--pkg/syncutil/downgradable_rwmutex_unsafe.go5
-rw-r--r--pkg/tcpip/header/BUILD5
-rw-r--r--pkg/tcpip/header/ipv6.go49
-rw-r--r--pkg/tcpip/header/ndp_options.go123
-rw-r--r--pkg/tcpip/header/ndp_test.go215
-rw-r--r--pkg/tcpip/ports/BUILD2
-rw-r--r--pkg/tcpip/ports/ports.go148
-rw-r--r--pkg/tcpip/ports/ports_test.go182
-rw-r--r--pkg/tcpip/stack/ndp.go529
-rw-r--r--pkg/tcpip/stack/ndp_test.go794
-rw-r--r--pkg/tcpip/stack/nic.go74
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go89
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go25
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go1
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go227
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go3
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go26
82 files changed, 3488 insertions, 878 deletions
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index f78315ebf..6663a199c 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -16,15 +16,17 @@ package linux
// Commands from linux/fcntl.h.
const (
- F_DUPFD = 0x0
- F_GETFD = 0x1
- F_SETFD = 0x2
- F_GETFL = 0x3
- F_SETFL = 0x4
- F_SETLK = 0x6
- F_SETLKW = 0x7
- F_SETOWN = 0x8
- F_GETOWN = 0x9
+ F_DUPFD = 0
+ F_GETFD = 1
+ F_SETFD = 2
+ F_GETFL = 3
+ F_SETFL = 4
+ F_SETLK = 6
+ F_SETLKW = 7
+ F_SETOWN = 8
+ F_GETOWN = 9
+ F_SETOWN_EX = 15
+ F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
F_SETPIPE_SZ = 1024 + 7
F_GETPIPE_SZ = 1024 + 8
@@ -32,9 +34,9 @@ const (
// Commands for F_SETLK.
const (
- F_RDLCK = 0x0
- F_WRLCK = 0x1
- F_UNLCK = 0x2
+ F_RDLCK = 0
+ F_WRLCK = 1
+ F_UNLCK = 2
)
// Flags for fcntl.
@@ -42,7 +44,7 @@ const (
FD_CLOEXEC = 00000001
)
-// Lock structure for F_SETLK.
+// Flock is the lock structure for F_SETLK.
type Flock struct {
Type int16
Whence int16
@@ -52,3 +54,16 @@ type Flock struct {
Pid int32
_ [4]byte
}
+
+// Flags for F_SETOWN_EX and F_GETOWN_EX.
+const (
+ F_OWNER_TID = 0
+ F_OWNER_PID = 1
+ F_OWNER_PGRP = 2
+)
+
+// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+type FOwnerEx struct {
+ Type int32
+ PID int32
+}
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c9ee098f4..0f014d27f 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -144,9 +144,13 @@ const (
ModeCharacterDevice = S_IFCHR
ModeNamedPipe = S_IFIFO
- ModeSetUID = 04000
- ModeSetGID = 02000
- ModeSticky = 01000
+ S_ISUID = 04000
+ S_ISGID = 02000
+ S_ISVTX = 01000
+
+ ModeSetUID = S_ISUID
+ ModeSetGID = S_ISGID
+ ModeSticky = S_ISVTX
ModeUserAll = 0700
ModeUserRead = 0400
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 9e7db8b30..67daa6c24 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return c.Native(0), nil
}
@@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
_, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
return err
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return nil
}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 1f78d54a2..e1f2fea60 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -22,6 +22,7 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -56,6 +57,9 @@ type Profile struct {
// traceFile is the current execution trace output file.
traceFile *fd.FD
+
+ // Kernel is the kernel under profile.
+ Kernel *kernel.Kernel
}
// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
@@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
return err
}
+ // Ensure all trace contexts are registered.
+ p.Kernel.RebuildTraceContexts()
+
p.traceFile = output
return nil
}
@@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error {
defer p.mu.Unlock()
if p.traceFile == nil {
- return errors.New("Execution tracing not start")
+ return errors.New("Execution tracing not started")
}
+ // Similarly to the case above, if tasks have not ended traces, we will
+ // lose information. Thus we need to rebuild the tasks in order to have
+ // complete information. This will not lose information if multiple
+ // traces are overlapping.
+ p.Kernel.RebuildTraceContexts()
+
trace.Stop()
p.traceFile.Close()
p.traceFile = nil
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index c35faeb4c..ced51c66c 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error {
}
// Process contains information about a single process in a Sandbox.
-// TODO(b/117881927): Implement TTY field.
type Process struct {
UID auth.KUID `json:"uid"`
PID kernel.ThreadID `json:"pid"`
// Parent PID
- PPID kernel.ThreadID `json:"ppid"`
+ PPID kernel.ThreadID `json:"ppid"`
+ Threads []kernel.ThreadID `json:"threads"`
// Processor utilization
C int32 `json:"c"`
+ // TTY name of the process. Will be of the form "pts/N" if there is a
+ // TTY, or "?" if there is not.
+ TTY string `json:"tty"`
// Start time
STime string `json:"stime"`
// CPU time
@@ -285,18 +288,19 @@ type Process struct {
}
// ProcessListToTable prints a table with the following format:
-// UID PID PPID C STIME TIME CMD
-// 0 1 0 0 14:04 505262ns tail
+// UID PID PPID C TTY STIME TIME CMD
+// 0 1 0 0 pty/4 14:04 505262ns tail
func ProcessListToTable(pl []*Process) string {
var buf bytes.Buffer
tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0)
- fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD")
+ fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD")
for _, d := range pl {
- fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s",
+ fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s",
d.UID,
d.PID,
d.PPID,
d.C,
+ d.TTY,
d.STime,
d.Time,
d.Cmd)
@@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string {
// ProcessListToJSON will return the JSON representation of ps.
func ProcessListToJSON(pl []*Process) (string, error) {
- b, err := json.Marshal(pl)
+ b, err := json.MarshalIndent(pl, "", " ")
if err != nil {
return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err)
}
@@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ts := k.TaskSet()
now := k.RealtimeClock().Now()
for _, tg := range ts.Root.ThreadGroups() {
- pid := tg.PIDNamespace().IDOfThreadGroup(tg)
+ pidns := tg.PIDNamespace()
+ pid := pidns.IDOfThreadGroup(tg)
+
// If tg has already been reaped ignore it.
if pid == 0 {
continue
@@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ppid := kernel.ThreadID(0)
if p := tg.Leader().Parent(); p != nil {
- ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup())
+ ppid = pidns.IDOfThreadGroup(p.ThreadGroup())
}
+ threads := tg.MemberIDs(pidns)
*out = append(*out, &Process{
- UID: tg.Leader().Credentials().EffectiveKUID,
- PID: pid,
- PPID: ppid,
- STime: formatStartTime(now, tg.Leader().StartTime()),
- C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
- Time: tg.CPUStats().SysTime.String(),
- Cmd: tg.Leader().Name(),
+ UID: tg.Leader().Credentials().EffectiveKUID,
+ PID: pid,
+ PPID: ppid,
+ Threads: threads,
+ STime: formatStartTime(now, tg.Leader().StartTime()),
+ C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
+ Time: tg.CPUStats().SysTime.String(),
+ Cmd: tg.Leader().Name(),
+ TTY: ttyName(tg.TTY()),
})
}
sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID })
@@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 {
}
return int32(percentCPU)
}
+
+func ttyName(tty *kernel.TTY) string {
+ if tty == nil {
+ return "?"
+ }
+ return fmt.Sprintf("pts/%d", tty.Index)
+}
diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go
index d8ada2694..0a88459b2 100644
--- a/pkg/sentry/control/proc_test.go
+++ b/pkg/sentry/control/proc_test.go
@@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) {
}{
{
pl: []*Process{},
- expected: "UID PID PPID C STIME TIME CMD",
+ expected: "UID PID PPID C TTY STIME TIME CMD",
},
{
pl: []*Process{
@@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) {
PID: 0,
PPID: 0,
C: 0,
+ TTY: "?",
STime: "0",
Time: "0",
Cmd: "zero",
@@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) {
PID: 1,
PPID: 1,
C: 1,
+ TTY: "pts/4",
STime: "1",
Time: "1",
Cmd: "one",
},
},
- expected: `UID PID PPID C STIME TIME CMD
-0 0 0 0 0 0 zero
-1 1 1 1 1 1 one`,
+ expected: `UID PID PPID C TTY STIME TIME CMD
+0 0 0 0 ? 0 0 zero
+1 1 1 1 pts/4 1 1 one`,
},
}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 0da608548..4e358a46a 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -143,9 +143,9 @@ type session struct {
// socket files. This allows unix domain sockets to be used with paths that
// belong to a gofer.
//
- // TODO(b/77154739): there are few possible races with someone stat'ing the
- // file and another deleting it concurrently, where the file will not be
- // reported as socket file.
+ // TODO(gvisor.dev/issue/1200): there are few possible races with someone
+ // stat'ing the file and another deleting it concurrently, where the file
+ // will not be reported as socket file.
endpoints *endpointMaps `state:"wait"`
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 87184ec67..0e46c5fb7 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -67,29 +67,28 @@ type taskDir struct {
var _ fs.InodeOperations = (*taskDir)(nil)
// newTaskDir creates a new proc task entry.
-func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode {
+func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
contents := map[string]*fs.Inode{
- "auxv": newAuxvec(t, msrc),
- "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- "comm": newComm(t, msrc),
- "environ": newExecArgInode(t, msrc, environExecArg),
- "exe": newExe(t, msrc),
- "fd": newFdDir(t, msrc),
- "fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newGIDMap(t, msrc),
- // FIXME(b/123511468): create the correct io file for threads.
- "io": newIO(t, msrc),
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
"maps": newMaps(t, msrc),
"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
"ns": newNamespaceDir(t, msrc),
"smaps": newSmaps(t, msrc),
- "stat": newTaskStat(t, msrc, showSubtasks, p.pidns),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
"statm": newStatm(t, msrc),
"status": newStatus(t, msrc, p.pidns),
"uid_map": newUIDMap(t, msrc),
}
- if showSubtasks {
+ if isThreadGroup {
contents["task"] = p.newSubtasks(t, msrc)
}
if len(p.cgroupControllers) > 0 {
@@ -619,8 +618,11 @@ type ioData struct {
ioUsage
}
-func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
- return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
+ if isThreadGroup {
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+ }
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t)
}
// NeedsUpdate returns whether the generation is old or not.
@@ -639,7 +641,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
io.Accumulate(i.IOUsage())
var buf bytes.Buffer
- fmt.Fprintf(&buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead)
fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten)
fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls)
fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls)
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index ff8138820..917f90cc0 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -53,8 +53,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal
d: d,
n: n,
ld: newLineDiscipline(termios),
- masterKTTY: &kernel.TTY{},
- slaveKTTY: &kernel.TTY{},
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
}
t.EnableLeakCheck("tty.Terminal")
return &t
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 94cd74095..177ce2cb9 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf
ctx := contexttest.Context(b)
creds := auth.CredentialsFromContext(ctx)
- if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}); err != nil {
+ if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ }); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
return func() {
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 307e4d68c..e9f756732 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -147,55 +147,54 @@ func TestSeek(t *testing.T) {
t.Fatalf("vfsfs.OpenAt failed: %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
t.Errorf("expected seek position 0, got %d and error %v", n, err)
}
- stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{})
+ stat, err := fd.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err)
}
// We should be able to seek beyond the end of file.
size := int64(stat.Size)
- if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
+ if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
+ if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err)
}
// Make sure negative offsets work with SEEK_CUR.
- if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
// Make sure SEEK_END works with regular files.
- switch fd.Impl().(type) {
- case *regularFileFD:
+ if _, ok := fd.Impl().(*regularFileFD); ok {
// Seek back to 0.
- if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", 0, n, err)
}
// Seek forward beyond EOF.
- if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
}
@@ -456,7 +455,7 @@ func TestRead(t *testing.T) {
want := make([]byte, 1)
for {
n, err := f.Read(want)
- fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+ fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
if diff := cmp.Diff(got, want); diff != "" {
t.Errorf("file data mismatch (-want +got):\n%s", diff)
@@ -464,7 +463,7 @@ func TestRead(t *testing.T) {
// Make sure there is no more file data left after getting EOF.
if n == 0 || err == io.EOF {
- if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+ if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
}
@@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) {
}
cb := &iterDirentsCb{}
- if err = fd.Impl().IterDirents(ctx, cb); err != nil {
+ if err = fd.IterDirents(ctx, cb); err != nil {
t.Fatalf("dir fd.IterDirents() failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go
index ea6417ce7..4a7a94a52 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go
@@ -394,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
}
defer mountPoint.DecRef()
// Create and mount the submount.
- if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.GetFilesystemOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go
index a3a870571..5bf527c80 100644
--- a/pkg/sentry/fsimpl/memfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/memfs/pipe_test.go
@@ -194,7 +194,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
readData := make([]byte, 1)
dst := usermem.BytesIOSequence(readData)
- bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
if err != syserror.ErrWouldBlock {
t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
}
@@ -207,7 +207,7 @@ func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
writeData := []byte(msg)
src := usermem.BytesIOSequence(writeData)
- bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{})
+ bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{})
if err != nil {
t.Fatalf("error writing to pipe %q: %v", fileName, err)
}
@@ -220,7 +220,7 @@ func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg
func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
readData := make([]byte, len(msg))
dst := usermem.BytesIOSequence(readData)
- bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
if err != nil {
t.Fatalf("error reading from pipe %q: %v", fileName, err)
}
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go
index c36c4aff5..0e016bca5 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/sentry/fsimpl/proc/filesystems.go
@@ -19,7 +19,7 @@ package proc
// +stateify savable
type filesystemsData struct{}
-// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
+// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for
// filesystemsData. We would need to retrive filesystem names from
// vfs.VirtualFilesystem. Also needs vfs replacement for
// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go
index e81b1e910..8683cf677 100644
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ b/pkg/sentry/fsimpl/proc/mounts.go
@@ -16,7 +16,7 @@ package proc
import "gvisor.dev/gvisor/pkg/sentry/kernel"
-// TODO(b/138862512): Implement mountInfoFile and mountsFile.
+// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile.
// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo.
//
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 28ba950bd..bd3fb4c03 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -841,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
AbstractSocketNamespace: args.AbstractSocketNamespace,
ContainerID: args.ContainerID,
}
- if _, err := k.tasks.NewTask(config); err != nil {
+ t, err := k.tasks.NewTask(config)
+ if err != nil {
return nil, 0, err
}
+ t.traceExecEvent(tc) // Simulate exec for tracing.
// Success.
tgid := k.tasks.Root.IDOfThreadGroup(tg)
@@ -1118,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
return lastErr
}
+// RebuildTraceContexts rebuilds the trace context for all tasks.
+//
+// Unfortunately, if these are built while tracing is not enabled, then we will
+// not have meaningful trace data. Rebuilding here ensures that we can do so
+// after tracing has been enabled.
+func (k *Kernel) RebuildTraceContexts() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ for t, tid := range k.tasks.Root.tids {
+ t.rebuildTraceContext(tid)
+ }
+}
+
// FeatureSet returns the FeatureSet.
func (k *Kernel) FeatureSet() *cpuid.FeatureSet {
return k.featureSet
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 93fe68a3e..de9617e9d 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred
return syserror.ERANGE
}
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = val
sem.pid = pid
s.changeTime = ktime.NowFromContext(ctx)
@@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
for i, val := range vals {
sem := &s.sems[i]
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = int16(val)
sem.pid = pid
sem.wakeWaiters()
@@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch
}
// All operations succeeded, apply them.
- // TODO(b/29354920): handle undo operations.
+ // TODO(gvisor.dev/issue/137): handle undo operations.
for i, v := range tmpVals {
s.sems[i].value = v
s.sems[i].wakeWaiters()
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 220fa73a2..2fdee0282 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
return nil
}
+// LookupName looks up a syscall name.
+func (s *SyscallTable) LookupName(sysno uintptr) string {
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Name
+ }
+ return fmt.Sprintf("sys_%d", sysno) // Unlikely.
+}
+
// LookupEmulate looks up an emulation syscall number.
func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 80c8e5464..ab0c6c4aa 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -15,6 +15,8 @@
package kernel
import (
+ gocontext "context"
+ "runtime/trace"
"sync"
"sync/atomic"
@@ -390,7 +392,14 @@ type Task struct {
// logPrefix is a string containing the task's thread ID in the root PID
// namespace, and is prepended to log messages emitted by Task.Infof etc.
- logPrefix atomic.Value `state:".(string)"`
+ logPrefix atomic.Value `state:"nosave"`
+
+ // traceContext and traceTask are both used for tracing, and are
+ // updated along with the logPrefix in updateInfoLocked.
+ //
+ // These are exclusive to the task goroutine.
+ traceContext gocontext.Context `state:"nosave"`
+ traceTask *trace.Task `state:"nosave"`
// creds is the task's credentials.
//
@@ -528,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) {
t.ptraceTracer.Store(tracer)
}
-func (t *Task) saveLogPrefix() string {
- return t.logPrefix.Load().(string)
-}
-
-func (t *Task) loadLogPrefix(prefix string) {
- t.logPrefix.Store(prefix)
-}
-
func (t *Task) saveSyscallFilters() []bpf.Program {
if f := t.syscallFilters.Load(); f != nil {
return f.([]bpf.Program)
@@ -549,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) {
// afterLoad is invoked by stateify.
func (t *Task) afterLoad() {
+ t.updateInfoLocked()
t.interruptChan = make(chan struct{}, 1)
t.gosched.State = TaskGoroutineNonexistent
if t.stop != nil {
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index dd69939f9..4a4a69ee2 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -16,6 +16,7 @@ package kernel
import (
"runtime"
+ "runtime/trace"
"time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
runtime.Gosched()
}
+ region := trace.StartRegion(t.traceContext, blockRegion)
select {
case <-C:
+ region.End()
t.SleepFinish(true)
+ // Woken by event.
return nil
case <-interrupt:
+ region.End()
t.SleepFinish(false)
// Return the indicated error on interrupt.
return syserror.ErrInterrupted
case <-timerChan:
- // We've timed out.
+ region.End()
t.SleepFinish(true)
+ // We've timed out.
return syserror.ETIMEDOUT
}
}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 0916fd658..3eadfedb4 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
// nt that it must receive before its task goroutine starts running.
tid := nt.k.tasks.Root.IDOfTask(nt)
defer nt.Start(tid)
+ t.traceCloneEvent(tid)
// "If fork/clone and execve are allowed by @prog, any child processes will
// be constrained to the same filters and system call ABI as the parent." -
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 17a089b90..90a6190f1 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct {
}
func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
+ t.traceExecEvent(r.tc)
t.tg.pidns.owner.mu.Lock()
t.tg.execing = nil
if t.killed() {
@@ -253,7 +254,7 @@ func (t *Task) promoteLocked() {
t.tg.leader = t
t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t])
- t.updateLogPrefixLocked()
+ t.updateInfoLocked()
// Reap the original leader. If it has a tracer, detach it instead of
// waiting for it to acknowledge the original leader's death.
oldLeader.exitParentNotified = true
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 535f03e50..435761e5a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState {
type runExitMain struct{}
func (*runExitMain) execute(t *Task) taskRunState {
+ t.traceExitEvent()
lastExiter := t.exitThreadGroup()
// If the task has a cleartid, and the thread group wasn't killed by a
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index a29e9b9eb..0fb3661de 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -16,6 +16,7 @@ package kernel
import (
"fmt"
+ "runtime/trace"
"sort"
"gvisor.dev/gvisor/pkg/log"
@@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() {
}
}
-// updateLogPrefix updates the task's cached log prefix to reflect its
-// current thread ID.
+// trace definitions.
+//
+// Note that all region names are prefixed by ':' in order to ensure that they
+// are lexically ordered before all system calls, which use the naked system
+// call name (e.g. "read") for maximum clarity.
+const (
+ traceCategory = "task"
+ runRegion = ":run"
+ blockRegion = ":block"
+ cpuidRegion = ":cpuid"
+ faultRegion = ":fault"
+)
+
+// updateInfoLocked updates the task's cached log prefix and tracing
+// information to reflect its current thread ID.
//
// Preconditions: The task's owning TaskSet.mu must be locked.
-func (t *Task) updateLogPrefixLocked() {
+func (t *Task) updateInfoLocked() {
// Use the task's TID in the root PID namespace for logging.
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t]))
+ tid := t.tg.pidns.owner.Root.tids[t]
+ t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.rebuildTraceContext(tid)
+}
+
+// rebuildTraceContext rebuilds the trace context.
+//
+// Precondition: the passed tid must be the tid in the root namespace.
+func (t *Task) rebuildTraceContext(tid ThreadID) {
+ // Re-initialize the trace context.
+ if t.traceTask != nil {
+ t.traceTask.End()
+ }
+
+ // Note that we define the "task type" to be the dynamic TID. This does
+ // not align perfectly with the documentation for "tasks" in the
+ // tracing package. Tasks may be assumed to be bounded by analysis
+ // tools. However, if we just use a generic "task" type here, then the
+ // "user-defined tasks" page on the tracing dashboard becomes nearly
+ // unusable, as it loads all traces from all tasks.
+ //
+ // We can assume that the number of tasks in the system is not
+ // arbitrarily large (in general it won't be, especially for cases
+ // where we're collecting a brief profile), so using the TID is a
+ // reasonable compromise in this case.
+ t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid))
+}
+
+// traceCloneEvent is called when a new task is spawned.
+//
+// ntid must be the new task's ThreadID in the root namespace.
+func (t *Task) traceCloneEvent(ntid ThreadID) {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid)
+}
+
+// traceExitEvent is called when a task exits.
+func (t *Task) traceExitEvent() {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status())
+}
+
+// traceExecEvent is called when a task calls exec.
+func (t *Task) traceExecEvent(tc *TaskContext) {
+ if !trace.IsEnabled() {
+ return
+ }
+ d := tc.MemoryManager.Executable()
+ if d == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
+ return
+ }
+ defer d.DecRef()
+ root := t.fsContext.RootDirectory()
+ if root == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>")
+ return
+ }
+ defer root.DecRef()
+ n, _ := d.FullName(root)
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", n)
}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index c92266c59..d97f8c189 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -17,6 +17,7 @@ package kernel
import (
"bytes"
"runtime"
+ "runtime/trace"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.RUnlock()
}
+ region := trace.StartRegion(t.traceContext, runRegion)
t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+ region.End()
if clearSinglestep {
t.Arch().ClearSingleStep()
@@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState {
case platform.ErrContextSignalCPUID:
// Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
expected := arch.CPUIDInstruction[:]
found := make([]byte, len(expected))
_, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
@@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState {
// Skip the cpuid instruction.
t.Arch().CPUIDEmulate(t)
t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
// Resume execution.
return (*runApp)(nil)
}
+ region.End() // Not an actual CPUID, but required copy-in.
// The instruction at the given RIP was not a CPUID, and we
// fallthrough to the default signal deliver behavior below.
@@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState {
// an application-generated signal and we should continue execution
// normally.
if at.Any() {
+ region := trace.StartRegion(t.traceContext, faultRegion)
addr := usermem.Addr(info.Addr())
err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ region.End()
if err == nil {
// The fault was handled appropriately.
// We can resume running the application.
@@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState {
}
// Is this a vsyscall that we need emulate?
+ //
+ // Note that we don't track vsyscalls as part of a
+ // specific trace region. This is because regions don't
+ // stack, and the actual system call will count as a
+ // region. We should be able to easily identify
+ // vsyscalls by having a <fault><syscall> pair.
if at.Execute {
if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
return t.doVsyscall(addr, sysno)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index ae6fc4025..3522a4ae5 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
// Below this point, newTask is expected not to fail (there is no rollback
// of assignTIDsLocked or any of the following).
- // Logging on t's behalf will panic if t.logPrefix hasn't been initialized.
- // This is the earliest point at which we can do so (since t now has thread
- // IDs).
- t.updateLogPrefixLocked()
+ // Logging on t's behalf will panic if t.logPrefix hasn't been
+ // initialized. This is the earliest point at which we can do so
+ // (since t now has thread IDs).
+ t.updateInfoLocked()
if cfg.InheritParent != nil {
t.parent = cfg.InheritParent.parent
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index b543d536a..3180f5560 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -17,6 +17,7 @@ package kernel
import (
"fmt"
"os"
+ "runtime/trace"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
ctrl = ctrlStopAndReinvokeSyscall
} else {
fn := s.Lookup(sysno)
+ var region *trace.Region // Only non-nil if tracing == true.
+ if trace.IsEnabled() {
+ region = trace.StartRegion(t.traceContext, s.LookupName(sysno))
+ }
if fn != nil {
// Call our syscall implementation.
rval, ctrl, err = fn(t, args)
@@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
// Use the missing function if not found.
rval, err = t.SyscallTable().Missing(t, sysno, args)
}
+ if region != nil {
+ region.End()
+ }
}
if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 34f84487a..048de26dc 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -21,8 +21,19 @@ import "sync"
//
// +stateify savable
type TTY struct {
+ // Index is the terminal index. It is immutable.
+ Index uint32
+
mu sync.Mutex `state:"nosave"`
// tg is protected by mu.
tg *ThreadGroup
}
+
+// TTY returns the thread group's controlling terminal. If nil, there is no
+// controlling terminal.
+func (tg *ThreadGroup) TTY() *TTY {
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.tty
+}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index c2c3ec06e..6299a3e2f 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -408,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
start = vaddr
}
if vaddr < end {
+ // NOTE(b/37474556): Linux allows out-of-order
+ // segments, in violation of the spec.
ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
return loadedELF{}, syserror.ENOEXEC
}
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index 2f65db70b..ba1f9043d 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -16,7 +16,6 @@
package sighandling
import (
- "fmt"
"os"
"os/signal"
"reflect"
@@ -31,37 +30,25 @@ const numSignals = 32
// handleSignals listens for incoming signals and calls the given handler
// function.
//
-// It starts when the start channel is closed, stops when the stop channel
-// is closed, and closes done once it will no longer deliver signals to k.
-func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) {
+// It stops when the stop channel is closed. The done channel is closed once it
+// will no longer deliver signals to k.
+func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) {
// Build a select case.
- sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}}
+ sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}}
for _, sigchan := range sigchans {
sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)})
}
- started := false
for {
// Wait for a notification.
index, _, ok := reflect.Select(sc)
- // Was it the start / stop channel?
+ // Was it the stop channel?
if index == 0 {
if !ok {
- if !started {
- // start channel; start forwarding and
- // swap this case for the stop channel
- // to select stop requests.
- started = true
- sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}
- } else {
- // stop channel; stop forwarding and
- // clear this case so it is never
- // selected again.
- started = false
- close(done)
- sc[0].Chan = reflect.Value{}
- }
+ // Stop forwarding and notify that it's done.
+ close(done)
+ return
}
continue
}
@@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start,
// Otherwise, it was a signal on channel N. Index 0 represents the stop
// channel, so index N represents the channel for signal N.
- signal := linux.Signal(index)
-
- if !started {
- // Kernel cannot receive signals, either because it is
- // not ready yet or is shutting down.
- //
- // Kill ourselves if this signal would have killed the
- // process before PrepareForwarding was called. i.e., all
- // _SigKill signals; see Go
- // src/runtime/sigtab_linux_generic.go.
- //
- // Otherwise ignore the signal.
- //
- // TODO(b/114489875): Drop in Go 1.12, which uses tgkill
- // in runtime.raise.
- switch signal {
- case linux.SIGHUP, linux.SIGINT, linux.SIGTERM:
- dieFromSignal(signal)
- panic(fmt.Sprintf("Failed to die from signal %d", signal))
- default:
- continue
- }
- }
-
- // Pass the signal to the handler.
- handler(signal)
+ handler(linux.Signal(index))
}
}
-// PrepareHandler ensures that synchronous signals are passed to the given
-// handler function and returns a callback that starts signal delivery, which
-// itself returns a callback that stops signal handling.
+// StartSignalForwarding ensures that synchronous signals are passed to the
+// given handler function and returns a callback that stops signal delivery.
//
// Note that this function permanently takes over signal handling. After the
// stop callback, signals revert to the default Go runtime behavior, which
// cannot be overridden with external calls to signal.Notify.
-func PrepareHandler(handler func(linux.Signal)) func() func() {
- start := make(chan struct{})
+func StartSignalForwarding(handler func(linux.Signal)) func() {
stop := make(chan struct{})
done := make(chan struct{})
@@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() {
signal.Notify(sigchan, syscall.Signal(sig))
}
// Start up our listener.
- go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
+ go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
- return func() func() {
- close(start)
- return func() {
- close(stop)
- <-done
- }
+ return func() {
+ close(stop)
+ <-done
}
}
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index c303435d5..1ebe22d34 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -15,8 +15,6 @@
package sighandling
import (
- "fmt"
- "runtime"
"syscall"
"unsafe"
@@ -48,27 +46,3 @@ func IgnoreChildStop() error {
return nil
}
-
-// dieFromSignal kills the current process with sig.
-//
-// Preconditions: The default action of sig is termination.
-func dieFromSignal(sig linux.Signal) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- sa := sigaction{handler: linux.SIG_DFL}
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigaction failed: %v", e))
- }
-
- set := linux.MakeSignalSet(sig)
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigprocmask failed: %v", e))
- }
-
- if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil {
- panic(fmt.Sprintf("tgkill failed: %v", err))
- }
-
- panic("failed to die")
-}
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 782a3cb92..af1a4e95f 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -195,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
// the available space, we must align down.
//
// align must be >= 4 and each data int32 is 4 bytes. The length of the
- // header is already aligned, so if we align to the with of the data there
+ // header is already aligned, so if we align to the width of the data there
// are two cases:
// 1. The aligned length is less than the length of the header. The
// unaligned length was also less than the length of the header, so we
// can't write anything.
// 2. The aligned length is greater than or equal to the length of the
- // header. We can write the header plus zero or more datas. We can't write
- // a partial int32, so the length of the message will be
- // min(aligned length, header + datas).
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
if space < linux.SizeOfControlMessageHeader {
flags |= linux.MSG_CTRUNC
return buf, flags
@@ -240,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = binary.Marshal(buf, usermem.ByteOrder, data)
- // Check if we went over.
+ // If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
return hdrBuf
}
- // Fix up length.
+ // Update control message length to include data.
putUint64(ob, uint64(len(buf)-len(ob)))
return alignSlice(buf, align)
@@ -348,43 +348,62 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
-func addSpaceForCmsg(cmsgDataLen int, buf []byte) []byte {
- newBuf := make([]byte, 0, len(buf)+linux.SizeOfControlMessageHeader+cmsgDataLen)
- return append(newBuf, buf...)
-}
-
-// PackControlMessages converts the given ControlMessages struct into a buffer.
+// PackControlMessages packs control messages into the given buffer.
+//
// We skip control messages specific to Unix domain sockets.
-func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages) []byte {
- var buf []byte
- // The use of t.Arch().Width() is analogous to Linux's use of sizeof(long) in
- // CMSG_ALIGN.
- width := t.Arch().Width()
-
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
- buf = addSpaceForCmsg(int(width), buf)
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
if cmsgs.IP.HasInq {
// In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageInq, width), buf)
buf = PackInq(t, cmsgs.IP.Inq, buf)
}
if cmsgs.IP.HasTOS {
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTOS, width), buf)
buf = PackTOS(t, cmsgs.IP.TOS, buf)
}
if cmsgs.IP.HasTClass {
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTClass, width), buf)
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
return buf
}
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 8d9363aac..c957b0f1d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -45,7 +45,7 @@ const (
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
// maxControlLen is the maximum size of a control message buffer used in a
- // recvmsg syscall.
+ // recvmsg or sendmsg syscall.
maxControlLen = 1024
)
@@ -289,12 +289,12 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
@@ -334,12 +334,12 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
@@ -412,9 +412,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msg.Namelen = uint32(len(senderAddrBuf))
}
if controlLen > 0 {
- controlBuf = make([]byte, maxControlLen)
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
msg.Control = &controlBuf[0]
- msg.Controllen = maxControlLen
+ msg.Controllen = controlLen
}
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
@@ -489,7 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.ErrInvalidArgument
}
- controlBuf := control.PackControlMessages(t, controlMessages)
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
if uint64(src.NumBytes()) != srcs.NumBytes() {
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index d92399efd..fe5a46aa3 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -151,6 +151,8 @@ var Metrics = tcpip.Stats{
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."),
EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
index 9586f5923..b677e9eb3 100644
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
@@ -3,7 +3,6 @@ syntax = "proto3";
// package syscall_rpc is a set of networking related system calls that can be
// forwarded to a socket gofer.
//
-// TODO(b/77963526): Document individual RPCs.
package syscall_rpc;
message SendmsgRequest {
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 8c250c325..2389a9cdb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -43,6 +43,11 @@ type ControlMessages struct {
IP tcpip.ControlMessages
}
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release() {
+ c.Unix.Release()
+}
+
// Socket is the interface containing socket syscalls used by the syscall layer
// to redirect them to the appropriate implementation.
type Socket interface {
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 72ebf766d..d46421199 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -14,6 +14,7 @@ go_library(
"open.go",
"poll.go",
"ptrace.go",
+ "select.go",
"signal.go",
"socket.go",
"strace.go",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 5d57b75af..e603f858f 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{
20: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
21: makeSyscallInfo("access", Path, Oct),
22: makeSyscallInfo("pipe", PipeFDs),
- 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval),
+ 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval),
24: makeSyscallInfo("sched_yield"),
25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
26: makeSyscallInfo("msync", Hex, Hex, Hex),
@@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{
267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
268: makeSyscallInfo("fchmodat", FD, Path, Mode),
269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
- 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet),
271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
272: makeSyscallInfo("unshare", CloneFlags),
273: makeSyscallInfo("set_robust_list", Hex, Hex),
@@ -335,5 +335,33 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 318: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close.
+ 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex),
+ 321: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex),
+ 323: makeSyscallInfo("userfaultfd", Hex),
+ 324: makeSyscallInfo("membarrier", Hex, Hex),
+ 325: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex),
+ 330: makeSyscallInfo("pkey_alloc", Hex, Hex),
+ 331: makeSyscallInfo("pkey_free", Hex),
332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
}
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
new file mode 100644
index 000000000..92c18083d
--- /dev/null
+++ b/pkg/sentry/strace/select.go
@@ -0,0 +1,53 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+func fdsFromSet(t *kernel.Task, set []byte) []int {
+ var fds []int
+ // Append n if the n-th bit is 1.
+ for i, v := range set {
+ for j := 0; j < 8; j++ {
+ if (v>>uint(j))&1 == 1 {
+ fds = append(fds, i*8+j)
+ }
+ }
+ }
+ return fds
+}
+
+func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+ if addr == 0 {
+ return "null"
+ }
+
+ // Calculate the size of the fd set (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := uint(nfds % 8)
+
+ set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set))
+}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 94334f6d2..51f2efb39 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
i += linux.SizeOfControlMessageHeader
width := t.Arch().Width()
length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length < 0 {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 311389547..629c1f308 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case SelectFDSet:
+ output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
case Hex:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 3c389d375..e5d486c4e 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -206,6 +206,10 @@ const (
// PollFDs is an array of struct pollfd. The number of entries in the
// array is in the next argument.
PollFDs
+
+ // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of
+ // fds represented must be the first argument.
+ SelectFDSet
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 81e4f93a6..797542d28 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -260,7 +260,7 @@ var AMD64 = &kernel.SyscallTable{
217: syscalls.Supported("getdents64", Getdents64),
218: syscalls.Supported("set_tid_address", SetTidAddress),
219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
222: syscalls.Supported("timer_create", TimerCreate),
223: syscalls.Supported("timer_settime", TimerSettime),
@@ -367,11 +367,31 @@ var AMD64 = &kernel.SyscallTable{
324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- // Syscalls after 325 are "backports" from versions of Linux after 4.4.
+ // Syscalls implemented after 325 are "backports" from versions
+ // of Linux after 4.4.
326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
327: syscalls.Supported("preadv2", Preadv2),
328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
+ 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
Emulate: map[usermem.Addr]uintptr{
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index f1dd4b0c0..2bc7faff5 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -224,7 +224,7 @@ var ARM64 = &kernel.SyscallTable{
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
- 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
@@ -302,7 +302,26 @@ var ARM64 = &kernel.SyscallTable{
285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
286: syscalls.Supported("preadv2", Preadv2),
287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
+ 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
Emulate: map[usermem.Addr]uintptr{},
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 3b9181002..9bc2445a5 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -840,25 +840,42 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(newfd), nil, nil
}
-func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx {
ma := file.Async(nil)
if ma == nil {
- return 0
+ return linux.FOwnerEx{}
}
a := ma.(*fasync.FileAsync)
ot, otg, opg := a.Owner()
switch {
case ot != nil:
- return int32(t.PIDNamespace().IDOfTask(ot))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }
case otg != nil:
- return int32(t.PIDNamespace().IDOfThreadGroup(otg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }
case opg != nil:
- return int32(-t.PIDNamespace().IDOfProcessGroup(opg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }
default:
- return 0
+ return linux.FOwnerEx{}
}
}
+func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+ owner := fGetOwnEx(t, file)
+ if owner.Type == linux.F_OWNER_PGRP {
+ return -owner.PID
+ }
+ return owner.PID
+}
+
// fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux.
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
@@ -901,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
t.FDTable().SetFlags(fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
+ return 0, nil, nil
case linux.F_GETFL:
return uintptr(file.Flags().ToLinux()), nil, nil
case linux.F_SETFL:
flags := uint(args[2].Uint())
file.SetFlags(linuxToFlags(flags).Settable())
+ return 0, nil, nil
case linux.F_SETLK, linux.F_SETLKW:
// In Linux the file system can choose to provide lock operations for an inode.
// Normally pipe and socket types lack lock operations. We diverge and use a heavy
@@ -1008,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_SETOWN:
fSetOwn(t, file, args[2].Int())
return 0, nil, nil
+ case linux.F_GETOWN_EX:
+ addr := args[2].Pointer()
+ owner := fGetOwnEx(t, file)
+ _, err := t.CopyOut(addr, &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ addr := args[2].Pointer()
+ var owner linux.FOwnerEx
+ n, err := t.CopyIn(addr, &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ switch owner.Type {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID))
+ if tg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID))
+ if pg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return uintptr(n), nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
case linux.F_GET_SEALS:
val, err := tmpfs.GetSeals(file.Dirent.Inode)
return uintptr(val), nil, err
@@ -1035,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
}
- return 0, nil, nil
}
const (
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 7a13beac2..631dffec6 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
return remainingTimeout, n, err
}
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes int, nBitsInLastPartialByte uint) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
return 0, syserror.EINVAL
}
- // Capture all the provided input vectors.
- //
- // N.B. This only works on little-endian architectures.
- byteCount := (nfds + 7) / 8
-
- bitsInLastPartialByte := uint(nfds % 8)
- r := make([]byte, byteCount)
- w := make([]byte, byteCount)
- e := make([]byte, byteCount)
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := uint(nfds % 8)
- if readFDs != 0 {
- if _, err := t.CopyIn(readFDs, &r); err != nil {
- return 0, err
- }
- // Mask out bits above nfds.
- if bitsInLastPartialByte != 0 {
- r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if writeFDs != 0 {
- if _, err := t.CopyIn(writeFDs, &w); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if exceptFDs != 0 {
- if _, err := t.CopyIn(exceptFDs, &e); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
// Count how many FDs are actually being requested so that we can build
// a PollFD array.
fdCount := 0
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
v := r[i] | w[i] | e[i]
for v != 0 {
v &= (v - 1)
@@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Build the PollFD array.
pfd := make([]linux.PollFD, 0, fdCount)
var fd int32
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
rV, wV, eV := r[i], w[i], e[i]
v := rV | wV | eV
m := byte(1)
@@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
}
// Do the syscall, then count the number of bits set.
- _, _, err := pollBlock(t, pfd, timeout)
- if err != nil {
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
return 0, syserror.ConvertIntr(err, syserror.EINTR)
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index d8acae063..4b5aafcc0 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -764,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
- cms.Unix.Release()
+ cms.Release()
}
if int(msg.Flags) != mflags {
@@ -784,32 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
- defer cms.Unix.Release()
+ defer cms.Release()
controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
}
- if cms.IP.HasTimestamp {
- controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData)
- }
-
- if cms.IP.HasInq {
- // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- controlData = control.PackInq(t, cms.IP.Inq, controlData)
- }
-
- if cms.IP.HasTOS {
- controlData = control.PackTOS(t, cms.IP.TOS, controlData)
- }
-
- if cms.IP.HasTClass {
- controlData = control.PackTClass(t, cms.IP.TClass, controlData)
- }
-
if cms.Unix.Rights != nil {
controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
}
@@ -885,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
}
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
- cm.Unix.Release()
+ cm.Release()
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -1071,7 +1055,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
- controlMessages.Unix.Release()
+ controlMessages.Release()
}
return uintptr(n), err
}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 74a325309..59237c3b9 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -19,7 +19,6 @@ go_library(
"options.go",
"permissions.go",
"resolving_path.go",
- "syscalls.go",
"testutil.go",
"vfs.go",
],
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 34007eb57..4473dfce8 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -241,3 +241,96 @@ type IterDirentsCallback interface {
// called.
Handle(dirent Dirent) bool
}
+
+// OnClose is called when a file descriptor representing the FileDescription is
+// closed. Returning a non-nil error should not prevent the file descriptor
+// from being closed.
+func (fd *FileDescription) OnClose(ctx context.Context) error {
+ return fd.impl.OnClose(ctx)
+}
+
+// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
+func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) {
+ flags, err := fd.impl.StatusFlags(ctx)
+ flags |= linux.O_LARGEFILE
+ return flags, err
+}
+
+// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
+func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
+ return fd.impl.SetStatusFlags(ctx, flags)
+}
+
+// Stat returns metadata for the file represented by fd.
+func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ return fd.impl.Stat(ctx, opts)
+}
+
+// SetStat updates metadata for the file represented by fd.
+func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ return fd.impl.SetStat(ctx, opts)
+}
+
+// StatFS returns metadata for the filesystem containing the file represented
+// by fd.
+func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return fd.impl.StatFS(ctx)
+}
+
+// PRead reads from the file represented by fd into dst, starting at the given
+// offset, and returns the number of bytes read. PRead is permitted to return
+// partial reads with a nil error.
+func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return fd.impl.PRead(ctx, dst, offset, opts)
+}
+
+// Read is similar to PRead, but does not specify an offset.
+func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ return fd.impl.Read(ctx, dst, opts)
+}
+
+// PWrite writes src to the file represented by fd, starting at the given
+// offset, and returns the number of bytes written. PWrite is permitted to
+// return partial writes with a nil error.
+func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return fd.impl.PWrite(ctx, src, offset, opts)
+}
+
+// Write is similar to PWrite, but does not specify an offset.
+func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return fd.impl.Write(ctx, src, opts)
+}
+
+// IterDirents invokes cb on each entry in the directory represented by fd. If
+// IterDirents has been called since the last call to Seek, it continues
+// iteration from the end of the last call.
+func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return fd.impl.IterDirents(ctx, cb)
+}
+
+// Seek changes fd's offset (assuming one exists) and returns its new value.
+func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.impl.Seek(ctx, offset, whence)
+}
+
+// Sync has the semantics of fsync(2).
+func (fd *FileDescription) Sync(ctx context.Context) error {
+ return fd.impl.Sync(ctx)
+}
+
+// ConfigureMMap mutates opts to implement mmap(2) for the file represented by
+// fd.
+func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.impl.ConfigureMMap(ctx, opts)
+}
+
+// Ioctl implements the ioctl(2) syscall.
+func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.impl.Ioctl(ctx, uio, args)
+}
+
+// SyncFS instructs the filesystem containing fd to execute the semantics of
+// syncfs(2).
+func (fd *FileDescription) SyncFS(ctx context.Context) error {
+ return fd.vd.mount.fs.impl.Sync(ctx)
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index a5561dcbe..ac7799296 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) {
// The first read causes Generate to be called to fill the FD's buffer.
buf := make([]byte, 2)
ioseq := usermem.BytesIOSequence(buf)
- n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err := fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) {
}
// A second read without seeking is still at EOF.
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 0 || err != io.EOF {
t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err)
}
// Seeking to the beginning of the file causes it to be regenerated.
- n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET)
+ n, err = fd.Seek(ctx, 0, linux.SEEK_SET)
if n != 0 || err != nil {
t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err)
}
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) {
}
// PRead at the beginning of the file also causes it to be regenerated.
- n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{})
+ n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 76ff8cf51..dfbd2372a 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -47,6 +47,9 @@ func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) {
fs.refs = 1
fs.vfs = vfsObj
fs.impl = impl
+ vfsObj.filesystemsMu.Lock()
+ vfsObj.filesystems[fs] = struct{}{}
+ vfsObj.filesystemsMu.Unlock()
}
// VirtualFilesystem returns the containing VirtualFilesystem.
@@ -66,9 +69,28 @@ func (fs *Filesystem) IncRef() {
}
}
+// TryIncRef increments fs' reference count and returns true. If fs' reference
+// count is zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fs.
+func (fs *Filesystem) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fs.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
// DecRef decrements fs' reference count.
func (fs *Filesystem) DecRef() {
if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 {
+ fs.vfs.filesystemsMu.Lock()
+ delete(fs.vfs.filesystems, fs)
+ fs.vfs.filesystemsMu.Unlock()
fs.impl.Release()
} else if refs < 0 {
panic("Filesystem.decRef() called without holding a reference")
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 1c3b2e987..ec23ab0dd 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -18,6 +18,7 @@ import (
"math"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -133,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return mntns, nil
}
-// NewMount creates and mounts a Filesystem configured by the given arguments.
-func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *GetFilesystemOptions) error {
+// MountAt creates and mounts a Filesystem configured by the given arguments.
+func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
fsType := vfs.getFilesystemType(fsTypeName)
if fsType == nil {
return syserror.ENODEV
}
- fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
+ fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
return err
}
@@ -207,6 +208,68 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti
return nil
}
+// UmountAt removes the Mount at the given path.
+func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
+ if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
+ return syserror.EINVAL
+ }
+
+ // MNT_FORCE is currently unimplemented except for the permission check.
+ if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
+ return syserror.EPERM
+ }
+
+ vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ defer vd.DecRef()
+ if vd.dentry != vd.mount.root {
+ return syserror.EINVAL
+ }
+ vfs.mountMu.Lock()
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // TODO(jamieliu): Linux special-cases umount of the caller's root, which
+ // we don't implement yet (we'll just fail it since the caller holds a
+ // reference on it).
+
+ vfs.mounts.seq.BeginWrite()
+ if opts.Flags&linux.MNT_DETACH == 0 {
+ if len(vd.mount.children) != 0 {
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ // We are holding a reference on vd.mount.
+ expectedRefs := int64(1)
+ if !vd.mount.umounted {
+ expectedRefs = 2
+ }
+ if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ }
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{
+ eager: opts.Flags&linux.MNT_DETACH == 0,
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+ return nil
+}
+
type umountRecursiveOptions struct {
// If eager is true, ensure that future calls to Mount.tryIncMountedRef()
// on umounted mounts fail.
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 3aa73d911..3ecbc8fc1 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -46,6 +46,12 @@ type MknodOptions struct {
DevMinor uint32
}
+// MountOptions contains options to VirtualFilesystem.MountAt().
+type MountOptions struct {
+ // GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
+ GetFilesystemOptions GetFilesystemOptions
+}
+
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
// FilesystemImpl.OpenAt().
type OpenOptions struct {
@@ -114,6 +120,12 @@ type StatOptions struct {
Sync uint32
}
+// UmountOptions contains options to VirtualFilesystem.UmountAt().
+type UmountOptions struct {
+ // Flags contains flags as specified for umount2(2).
+ Flags uint32
+}
+
// WriteOptions contains options to FileDescription.PWrite(),
// FileDescriptionImpl.PWrite(), FileDescription.Write(), and
// FileDescriptionImpl.Write().
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f8e74355c..f1edb0680 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
return false
}
}
+
+// CheckSetStat checks that creds has permission to change the metadata of a
+// file with the given permissions, UID, and GID as specified by stat, subject
+// to the rules of Linux's fs/attr.c:setattr_prepare().
+func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+ if stat.Mask&linux.STATX_MODE != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ return syserror.EPERM
+ }
+ // TODO(b/30815691): "If the calling process is not privileged (Linux:
+ // does not have the CAP_FSETID capability), and the group of the file
+ // does not match the effective group ID of the process or one of its
+ // supplementary group IDs, the S_ISGID bit will be turned off, but
+ // this will not cause an error to be returned." - chmod(2)
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
+ return syserror.EPERM
+ }
+ // isDir is irrelevant in the following call to
+ // GenericCheckPermissions since ats == MayWrite means that
+ // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE
+ // applies, regardless of isDir.
+ if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// CanActAsOwner returns true if creds can act as the owner of a file with the
+// given owning UID, consistent with Linux's
+// fs/inode.c:inode_owner_or_capable().
+func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
+ if creds.EffectiveKUID == kuid {
+ return true
+ }
+ return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok()
+}
+
+// HasCapabilityOnFile returns true if creds has the given capability with
+// respect to a file with the given owning UID and GID, consistent with Linux's
+// kernel/capability.c:capable_wrt_inode_uidgid().
+func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
+ return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
+}
diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go
deleted file mode 100644
index 436151afa..000000000
--- a/pkg/sentry/vfs/syscalls.go
+++ /dev/null
@@ -1,237 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// PathOperation specifies the path operated on by a VFS method.
-//
-// PathOperation is passed to VFS methods by pointer to reduce memory copying:
-// it's somewhat large and should never escape. (Options structs are passed by
-// pointer to VFS and FileDescription methods for the same reason.)
-type PathOperation struct {
- // Root is the VFS root. References on Root are borrowed from the provider
- // of the PathOperation.
- //
- // Invariants: Root.Ok().
- Root VirtualDentry
-
- // Start is the starting point for the path traversal. References on Start
- // are borrowed from the provider of the PathOperation (i.e. the caller of
- // the VFS method to which the PathOperation was passed).
- //
- // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
- Start VirtualDentry
-
- // Path is the pathname traversed by this operation.
- Pathname string
-
- // If FollowFinalSymlink is true, and the Dentry traversed by the final
- // path component represents a symbolic link, the symbolic link should be
- // followed.
- FollowFinalSymlink bool
-}
-
-// GetDentryAt returns a VirtualDentry representing the given path, at which a
-// file must exist. A reference is taken on the returned VirtualDentry.
-func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return VirtualDentry{}, err
- }
- for {
- d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
- if err == nil {
- vd := VirtualDentry{
- mount: rp.mount,
- dentry: d,
- }
- rp.mount.IncRef()
- vfs.putResolvingPath(rp)
- return vd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return VirtualDentry{}, err
- }
- }
-}
-
-// MkdirAt creates a directory at the given path.
-func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
- // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
- // also honored." - mkdir(2)
- opts.Mode &= 01777
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
- }
- for {
- err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return err
- }
- }
-}
-
-// MknodAt creates a file of the given mode at the given path. It returns an
-// error from the syserror package.
-func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil
- }
- for {
- if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
- vfs.putResolvingPath(rp)
- return nil
- }
- // Handle mount traversals.
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return err
- }
- }
-}
-
-// OpenAt returns a FileDescription providing access to the file at the given
-// path. A reference is taken on the returned FileDescription.
-func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
- // Remove:
- //
- // - O_LARGEFILE, which we always report in FileDescription status flags
- // since only 64-bit architectures are supported at this time.
- //
- // - O_CLOEXEC, which affects file descriptors and therefore must be
- // handled outside of VFS.
- //
- // - Unknown flags.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
- // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
- if opts.Flags&linux.O_SYNC != 0 {
- opts.Flags |= linux.O_DSYNC
- }
- // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
- // with O_DIRECTORY and a writable access mode (to ensure that it fails on
- // filesystem implementations that do not support it).
- if opts.Flags&linux.O_TMPFILE != 0 {
- if opts.Flags&linux.O_DIRECTORY == 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
- return nil, syserror.EINVAL
- }
- }
- // O_PATH causes most other flags to be ignored.
- if opts.Flags&linux.O_PATH != 0 {
- opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
- }
- // "On Linux, the following bits are also honored in mode: [S_ISUID,
- // S_ISGID, S_ISVTX]" - open(2)
- opts.Mode &= 07777
-
- if opts.Flags&linux.O_NOFOLLOW != 0 {
- pop.FollowFinalSymlink = false
- }
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil, err
- }
- if opts.Flags&linux.O_DIRECTORY != 0 {
- rp.mustBeDir = true
- rp.mustBeDirOrig = true
- }
- for {
- fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return fd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return nil, err
- }
- }
-}
-
-// StatAt returns metadata for the file at the given path.
-func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return linux.Statx{}, err
- }
- for {
- stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return stat, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return linux.Statx{}, err
- }
- }
-}
-
-// StatusFlags returns file description status flags.
-func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- flags, err := fd.impl.StatusFlags(ctx)
- flags |= linux.O_LARGEFILE
- return flags, err
-}
-
-// SetStatusFlags sets file description status flags.
-func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- return fd.impl.SetStatusFlags(ctx, flags)
-}
-
-// TODO:
-//
-// - VFS.SyncAllFilesystems() for sync(2)
-//
-// - Something for syncfs(2)
-//
-// - VFS.LinkAt()
-//
-// - VFS.ReadlinkAt()
-//
-// - VFS.RenameAt()
-//
-// - VFS.RmdirAt()
-//
-// - VFS.SetStatAt()
-//
-// - VFS.StatFSAt()
-//
-// - VFS.SymlinkAt()
-//
-// - VFS.UmountAt()
-//
-// - VFS.UnlinkAt()
-//
-// - FileDescription.(almost everything)
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index f0cd3ffe5..7262b0d0a 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -20,6 +20,7 @@
// VirtualFilesystem.mountMu
// Dentry.mu
// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry
+// VirtualFilesystem.filesystemsMu
// VirtualFilesystem.fsTypesMu
//
// Locking Dentry.mu in multiple Dentries requires holding
@@ -28,6 +29,11 @@ package vfs
import (
"sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts.
@@ -67,6 +73,11 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
+ // filesystems contains all Filesystems. filesystems is protected by
+ // filesystemsMu.
+ filesystemsMu sync.Mutex
+ filesystems map[*Filesystem]struct{}
+
// fsTypes contains all FilesystemTypes that are usable in the
// VirtualFilesystem. fsTypes is protected by fsTypesMu.
fsTypesMu sync.RWMutex
@@ -77,12 +88,379 @@ type VirtualFilesystem struct {
func New() *VirtualFilesystem {
vfs := &VirtualFilesystem{
mountpoints: make(map[*Dentry]map[*Mount]struct{}),
+ filesystems: make(map[*Filesystem]struct{}),
fsTypes: make(map[string]FilesystemType),
}
vfs.mounts.Init()
return vfs
}
+// PathOperation specifies the path operated on by a VFS method.
+//
+// PathOperation is passed to VFS methods by pointer to reduce memory copying:
+// it's somewhat large and should never escape. (Options structs are passed by
+// pointer to VFS and FileDescription methods for the same reason.)
+type PathOperation struct {
+ // Root is the VFS root. References on Root are borrowed from the provider
+ // of the PathOperation.
+ //
+ // Invariants: Root.Ok().
+ Root VirtualDentry
+
+ // Start is the starting point for the path traversal. References on Start
+ // are borrowed from the provider of the PathOperation (i.e. the caller of
+ // the VFS method to which the PathOperation was passed).
+ //
+ // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
+ Start VirtualDentry
+
+ // Path is the pathname traversed by this operation.
+ Pathname string
+
+ // If FollowFinalSymlink is true, and the Dentry traversed by the final
+ // path component represents a symbolic link, the symbolic link should be
+ // followed.
+ FollowFinalSymlink bool
+}
+
+// GetDentryAt returns a VirtualDentry representing the given path, at which a
+// file must exist. A reference is taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return VirtualDentry{}, err
+ }
+ for {
+ d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
+ if err == nil {
+ vd := VirtualDentry{
+ mount: rp.mount,
+ dentry: d,
+ }
+ rp.mount.IncRef()
+ vfs.putResolvingPath(rp)
+ return vd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return VirtualDentry{}, err
+ }
+ }
+}
+
+// LinkAt creates a hard link at newpop representing the existing file at
+// oldpop.
+func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ rp, err := vfs.getResolvingPath(creds, newpop)
+ if err != nil {
+ oldVD.DecRef()
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
+ if err == nil {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// MkdirAt creates a directory at the given path.
+func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
+ // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
+ // also honored." - mkdir(2)
+ opts.Mode &= 0777 | linux.S_ISVTX
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil
+ }
+ for {
+ if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ // Handle mount traversals.
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// OpenAt returns a FileDescription providing access to the file at the given
+// path. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ // Remove:
+ //
+ // - O_LARGEFILE, which we always report in FileDescription status flags
+ // since only 64-bit architectures are supported at this time.
+ //
+ // - O_CLOEXEC, which affects file descriptors and therefore must be
+ // handled outside of VFS.
+ //
+ // - Unknown flags.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
+ if opts.Flags&linux.O_SYNC != 0 {
+ opts.Flags |= linux.O_DSYNC
+ }
+ // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
+ // with O_DIRECTORY and a writable access mode (to ensure that it fails on
+ // filesystem implementations that do not support it).
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ if opts.Flags&linux.O_DIRECTORY == 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
+ return nil, syserror.EINVAL
+ }
+ }
+ // O_PATH causes most other flags to be ignored.
+ if opts.Flags&linux.O_PATH != 0 {
+ opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
+ }
+ // "On Linux, the following bits are also honored in mode: [S_ISUID,
+ // S_ISGID, S_ISVTX]" - open(2)
+ opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+ if opts.Flags&linux.O_NOFOLLOW != 0 {
+ pop.FollowFinalSymlink = false
+ }
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil, err
+ }
+ if opts.Flags&linux.O_DIRECTORY != 0 {
+ rp.mustBeDir = true
+ rp.mustBeDirOrig = true
+ }
+ for {
+ fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return fd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// ReadlinkAt returns the target of the symbolic link at the given path.
+func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return "", err
+ }
+ for {
+ target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return target, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// RenameAt renames the file at oldpop to newpop.
+func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ rp, err := vfs.getResolvingPath(creds, newpop)
+ if err != nil {
+ oldVD.DecRef()
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts)
+ if err == nil {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// RmdirAt removes the directory at the given path.
+func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.RmdirAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SetStatAt changes metadata for the file at the given path.
+func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// StatAt returns metadata for the file at the given path.
+func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ for {
+ stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return stat, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statx{}, err
+ }
+ }
+}
+
+// StatFSAt returns metadata for the filesystem containing the file at the
+// given path.
+func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ for {
+ statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return statfs, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statfs{}, err
+ }
+ }
+}
+
+// SymlinkAt creates a symbolic link at the given path with the given target.
+func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// UnlinkAt deletes the non-directory file at the given path.
+func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SyncAllFilesystems has the semantics of Linux's sync(2).
+func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ fss := make(map[*Filesystem]struct{})
+ vfs.filesystemsMu.Lock()
+ for fs := range vfs.filesystems {
+ if !fs.TryIncRef() {
+ continue
+ }
+ fss[fs] = struct{}{}
+ }
+ vfs.filesystemsMu.Unlock()
+ var retErr error
+ for fs := range fss {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef()
+ }
+ return retErr
+}
+
// A VirtualDentry represents a node in a VFS tree, by combining a Dentry
// (which represents a node in a Filesystem's tree) and a Mount (which
// represents the Filesystem's position in a VFS mount tree).
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index ecce6c69f..5e4611333 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -287,7 +287,9 @@ func (w *Watchdog) runTurn() {
if !ok {
// New stuck task detected.
//
- // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel.
+ // Note that tasks blocked doing IO may be considered stuck in kernel,
+ // unless they are surrounded b
+ // Task.UninterruptibleSleepStart/Finish.
tc = &offender{lastUpdateTime: lastUpdateTime}
stuckTasks.Increment()
newTaskFound = true
diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD
index b06a90bef..cb1f41628 100644
--- a/pkg/syncutil/BUILD
+++ b/pkg/syncutil/BUILD
@@ -31,8 +31,6 @@ go_template(
go_library(
name = "syncutil",
srcs = [
- "downgradable_rwmutex_1_12_unsafe.go",
- "downgradable_rwmutex_1_13_unsafe.go",
"downgradable_rwmutex_unsafe.go",
"memmove_unsafe.go",
"norace_unsafe.go",
diff --git a/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go b/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go
deleted file mode 100644
index 7c6336e62..000000000
--- a/pkg/syncutil/downgradable_rwmutex_1_12_unsafe.go
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Copyright 2019 The gVisor Authors.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.12
-// +build !go1.13
-
-// TODO(b/133868570): Delete once Go 1.12 is no longer supported.
-
-package syncutil
-
-import _ "unsafe"
-
-//go:linkname runtimeSemrelease112 sync.runtime_Semrelease
-func runtimeSemrelease112(s *uint32, handoff bool)
-
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int) {
- // 'skipframes' is only available starting from 1.13.
- runtimeSemrelease112(s, handoff)
-}
diff --git a/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go b/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go
deleted file mode 100644
index 3c3673119..000000000
--- a/pkg/syncutil/downgradable_rwmutex_1_13_unsafe.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Copyright 2019 The gVisor Authors.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.13
-// +build !go1.15
-
-// Check go:linkname function signatures when updating Go version.
-
-package syncutil
-
-import _ "unsafe"
-
-//go:linkname runtimeSemrelease sync.runtime_Semrelease
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go
index 07feca402..51e11555d 100644
--- a/pkg/syncutil/downgradable_rwmutex_unsafe.go
+++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.12
+// +build go1.13
// +build !go1.15
// Check go:linkname function signatures when updating Go version.
@@ -27,6 +27,9 @@ import (
//go:linkname runtimeSemacquire sync.runtime_Semacquire
func runtimeSemacquire(s *uint32)
+//go:linkname runtimeSemrelease sync.runtime_Semrelease
+func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
+
// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock
// method.
type DowngradableRWMutex struct {
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index a3485b35c..8392cb9e5 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -55,5 +55,8 @@ go_test(
"ndp_test.go",
],
embed = [":header"],
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 0caa51c1e..5275b34d4 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -90,6 +90,18 @@ const (
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ // IIDSize is the size of an interface identifier (IID), in bytes, as
+ // defined by RFC 4291 section 2.5.1.
+ IIDSize = 8
+
+ // IIDOffsetInIPv6Address is the offset, in bytes, from the start
+ // of an IPv6 address to the beginning of the interface identifier
+ // (IID) for auto-generated addresses. That is, all bytes before
+ // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
+ // bytes including and after the IIDOffsetInIPv6Address-th byte are
+ // for the IID.
+ IIDOffsetInIPv6Address = 8
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -266,6 +278,28 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// EthernetAdddressToEUI64IntoBuf populates buf with a EUI-64 from a 48-bit
+// Ethernet/MAC address.
+//
+// buf MUST be at least 8 bytes.
+func EthernetAdddressToEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
+ buf[0] = linkAddr[0] ^ 2
+ buf[1] = linkAddr[1]
+ buf[2] = linkAddr[2]
+ buf[3] = 0xFE
+ buf[4] = 0xFE
+ buf[5] = linkAddr[3]
+ buf[6] = linkAddr[4]
+ buf[7] = linkAddr[5]
+}
+
+// EthernetAddressToEUI64 computes an EUI-64 from a 48-bit Ethernet/MAC address.
+func EthernetAddressToEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
+ var buf [IIDSize]byte
+ EthernetAdddressToEUI64IntoBuf(linkAddr, buf[:])
+ return buf
+}
+
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
@@ -275,18 +309,11 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
// The conversion is very nearly:
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
// Note the capital A. The conversion aa->Aa involves a bit flip.
- lladdrb := [16]byte{
- 0: 0xFE,
- 1: 0x80,
- 8: linkAddr[0] ^ 2,
- 9: linkAddr[1],
- 10: linkAddr[2],
- 11: 0xFF,
- 12: 0xFE,
- 13: linkAddr[3],
- 14: linkAddr[4],
- 15: linkAddr[5],
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
}
+ EthernetAdddressToEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
return tcpip.Address(lladdrb[:])
}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 1ca6199ef..06e0bace2 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -17,6 +17,7 @@ package header
import (
"encoding/binary"
"errors"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -85,6 +86,23 @@ const (
// within an NDPPrefixInformation.
ndpPrefixInformationPrefixOffset = 14
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType = 25
+
+ // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerLifetimeOffset = 2
+
+ // ndpRecursiveDNSServerAddressesOffset is the start of the addresses
+ // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerAddressesOffset = 6
+
+ // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS
+ // Server option's length field value when it contains at least one
+ // IPv6 address.
+ minNDPRecursiveDNSServerLength = 3
+
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
@@ -92,13 +110,13 @@ const (
)
var (
- // NDPPrefixInformationInfiniteLifetime is a value that represents
- // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix
- // Information option. Its value is (2^32 - 1)s = 4294967295s
+ // NDPInfiniteLifetime is a value that represents infinity for the
+ // 4-byte lifetime fields found in various NDP options. Its value is
+ // (2^32 - 1)s = 4294967295s.
//
// This is a variable instead of a constant so that tests can change
// this value to a smaller value. It should only be modified by tests.
- NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295
+ NDPInfiniteLifetime = time.Second * math.MaxUint32
)
// NDPOptionIterator is an iterator of NDPOption.
@@ -118,6 +136,7 @@ var (
ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted")
ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field")
ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC")
)
// Next returns the next element in the backing NDPOptions, or true if we are
@@ -182,6 +201,22 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
}
return NDPPrefixInformation(body), false, nil
+
+ case NDPRecursiveDNSServerOptionType:
+ // RFC 8106 section 5.3.1 outlines that the RDNSS option
+ // must have a minimum length of 3 so it contains at
+ // least one IPv6 address.
+ if l < minNDPRecursiveDNSServerLength {
+ return nil, true, ErrNDPInvalidLength
+ }
+
+ opt := NDPRecursiveDNSServer(body)
+ if len(opt.Addresses()) == 0 {
+ return nil, true, ErrNDPOptMalformedBody
+ }
+
+ return opt, false, nil
+
default:
// We do not yet recognize the option, just skip for
// now. This is okay because RFC 4861 allows us to
@@ -434,7 +469,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool {
//
// Note, a value of 0 implies the prefix should not be considered as on-link,
// and a value of infinity/forever is represented by
-// NDPPrefixInformationInfiniteLifetime.
+// NDPInfiniteLifetime.
func (o NDPPrefixInformation) ValidLifetime() time.Duration {
// The field is the time in seconds, as per RFC 4861 section 4.6.2.
return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:]))
@@ -447,7 +482,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration {
//
// Note, a value of 0 implies that addresses generated from the prefix should
// no longer remain preferred, and a value of infinity is represented by
-// NDPPrefixInformationInfiniteLifetime.
+// NDPInfiniteLifetime.
//
// Also note that the value of this field MUST NOT exceed the Valid Lifetime
// field to avoid preferring addresses that are no longer valid, for the
@@ -476,3 +511,79 @@ func (o NDPPrefixInformation) Subnet() tcpip.Subnet {
}
return addrWithPrefix.Subnet()
}
+
+// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by
+// RFC 8106 section 5.1.
+//
+// To make sure that the option meets its minimum length and does not end in the
+// middle of a DNS server's IPv6 address, the length of a valid
+// NDPRecursiveDNSServer must meet the following constraint:
+// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0
+type NDPRecursiveDNSServer []byte
+
+// Type returns the type of an NDP Recursive DNS Server option.
+//
+// Type implements NDPOption.Type.
+func (NDPRecursiveDNSServer) Type() uint8 {
+ return NDPRecursiveDNSServerOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPRecursiveDNSServer) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// Lifetime returns the length of time that the DNS server addresses
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the addresses should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+//
+// Lifetime may panic if o does not have enough bytes to hold the Lifetime
+// field.
+func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:]))
+}
+
+// Addresses returns the recursive DNS server IPv6 addresses that may be
+// used for name resolution.
+//
+// Note, some of the addresses returned MAY be link-local addresses.
+//
+// Addresses may panic if o does not hold valid IPv6 addresses.
+func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address {
+ l := len(o)
+ if l < ndpRecursiveDNSServerAddressesOffset {
+ return nil
+ }
+
+ l -= ndpRecursiveDNSServerAddressesOffset
+ if l%IPv6AddressSize != 0 {
+ return nil
+ }
+
+ buf := o[ndpRecursiveDNSServerAddressesOffset:]
+ var addrs []tcpip.Address
+ for len(buf) > 0 {
+ addr := tcpip.Address(buf[:IPv6AddressSize])
+ if !IsV6UnicastAddress(addr) {
+ return nil
+ }
+ addrs = append(addrs, addr)
+ buf = buf[IPv6AddressSize:]
+ }
+ return addrs
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index ad6daafcd..2c439d70c 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -369,6 +370,175 @@ func TestNDPPrefixInformationOption(t *testing.T) {
}
}
+func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 25, 3, 0, 0,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPRecursiveDNSServer(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Type(); got != 25 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ want := []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ }
+ if got := opt.Addresses(); !cmp.Equal(got, want) {
+ t.Errorf("got Addresses = %v, want = %v", got, want)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ addrs []tcpip.Address
+ }{
+ {
+ "Valid1Addr",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ },
+ },
+ {
+ "Valid2Addr",
+ []byte{
+ 25, 5, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ },
+ },
+ {
+ "Valid3Addr",
+ []byte{
+ 25, 7, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Iterator should get our option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
+ t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
// the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
@@ -473,6 +643,51 @@ func TestNDPOptionsIterCheck(t *testing.T) {
},
nil,
},
+ {
+ "InvalidRecursiveDNSServerCutsOffAddress",
+ []byte{
+ 25, 4, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 1, 2, 3, 4, 5, 6, 7,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ {
+ "InvalidRecursiveDNSServerInvalidLengthField",
+ []byte{
+ 25, 2, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8,
+ },
+ ErrNDPInvalidLength,
+ },
+ {
+ "RecursiveDNSServerTooSmall",
+ []byte{
+ 25, 1, 0, 0,
+ 0, 0, 0,
+ },
+ ErrNDPOptBufExhausted,
+ },
+ {
+ "RecursiveDNSServerMulticast",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ {
+ "RecursiveDNSServerUnspecified",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ ErrNDPOptMalformedBody,
+ },
}
for _, test := range tests {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 4839f0a65..e156b01f6 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 30cea8996..6c5e19e8f 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -41,6 +41,30 @@ type portDescriptor struct {
port uint16
}
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precidence over MostRecent.
+ LoadBalanced bool
+}
+
+func (f Flags) bits() reuseFlag {
+ var rf reuseFlag
+ if f.MostRecent {
+ rf |= mostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= loadBalancedFlag
+ }
+ return rf
+}
+
// PortManager manages allocating, reserving and releasing ports.
type PortManager struct {
mu sync.RWMutex
@@ -54,9 +78,59 @@ type PortManager struct {
hint uint32
}
+type reuseFlag int
+
+const (
+ mostRecentFlag reuseFlag = 1 << iota
+ loadBalancedFlag
+ nextFlag
+
+ flagMask = nextFlag - 1
+)
+
type portNode struct {
- reuse bool
- refs int
+ // refs stores the count for each possible flag combination.
+ refs [nextFlag]int
+}
+
+func (p portNode) totalRefs() int {
+ var total int
+ for _, r := range p.refs {
+ total += r
+ }
+ return total
+}
+
+// flagRefs returns the number of references with all specified flags.
+func (p portNode) flagRefs(flags reuseFlag) int {
+ var total int
+ for i, r := range p.refs {
+ if reuseFlag(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// allRefsHave returns if all references have all specified flags.
+func (p portNode) allRefsHave(flags reuseFlag) bool {
+ for i, r := range p.refs {
+ if reuseFlag(i)&flags == flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// intersectionRefs returns the set of flags shared by all references.
+func (p portNode) intersectionRefs() reuseFlag {
+ intersection := flagMask
+ for i, r := range p.refs {
+ if r > 0 {
+ intersection &= reuseFlag(i)
+ }
+ }
+ return intersection
}
// deviceNode is never empty. When it has no elements, it is removed from the
@@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode
// isAvailable checks whether binding is possible by device. If not binding to a
// device, check against all portNodes. If binding to a specific device, check
// against the unspecified device and the provided device.
-func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+//
+// If either of the port reuse flags is enabled on any of the nodes, all nodes
+// sharing a port must share at least one reuse flag. This matches Linux's
+// behavior.
+func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
+ flagBits := flags.bits()
if bindToDevice == 0 {
// Trying to binding all devices.
- if !reuse {
+ if flagBits == 0 {
// Can't bind because the (addr,port) is already bound.
return false
}
+ intersection := flagMask
for _, p := range d {
- if !p.reuse {
- // Can't bind because the (addr,port) was previously bound without reuse.
+ i := p.intersectionRefs()
+ intersection &= i
+ if intersection&flagBits == 0 {
+ // Can't bind because the (addr,port) was
+ // previously bound without reuse.
return false
}
}
return true
}
+ intersection := flagMask
+
if p, ok := d[0]; ok {
- if !reuse || !p.reuse {
+ intersection = p.intersectionRefs()
+ if intersection&flagBits == 0 {
return false
}
}
if p, ok := d[bindToDevice]; ok {
- if !reuse || !p.reuse {
+ i := p.intersectionRefs()
+ intersection &= i
+ if intersection&flagBits == 0 {
return false
}
}
@@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool {
if addr == anyIPAddress {
// If binding to the "any" address then check that there are no conflicts
// with all addresses.
for _, d := range b {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
@@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice
// Check that there is no conflict with the "any" address.
if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
// Check that this is no conflict with the provided address.
if d, ok := b[addr]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
@@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
+ return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse, bindToDevice) {
+ if !addrs.isAvailable(addr, flags, bindToDevice) {
return false
}
}
@@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
return false
}
+ flagBits := flags.bits()
// Reserve port on all network protocols.
for _, network := range networks {
@@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
d = make(deviceNode)
m[addr] = d
}
- if n, ok := d[bindToDevice]; ok {
- n.refs++
- d[bindToDevice] = n
- } else {
- d[bindToDevice] = portNode{reuse: reuse, refs: 1}
- }
+ n := d[bindToDevice]
+ n.refs[flagBits]++
+ d[bindToDevice] = n
}
return true
@@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
+ flagBits := flags.bits()
+
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
@@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
if !ok {
continue
}
- n.refs--
+ n.refs[flagBits]--
d[bindToDevice] = n
- if n.refs == 0 {
+ if n.refs == [nextFlag]int{} {
delete(d, bindToDevice)
}
if len(d) == 0 {
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 19f4833fc..d6969d050 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -33,7 +33,7 @@ type portReserveTestAction struct {
port uint16
ip tcpip.Address
want *tcpip.Error
- reuse bool
+ flags Flags
release bool
device tcpip.NICID
}
@@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) {
{port: 80, ip: fakeIPAddress1, want: nil},
/* N.B. Order of tests matters! */
{port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
},
},
{
@@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) {
/* release fakeIPAddress, but anyIPAddress is still inuse */
{port: 22, ip: fakeIPAddress, release: true},
{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
{port: 22, ip: anyIPAddress, want: nil, release: true},
{port: 22, ip: fakeIPAddress, want: nil},
@@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 00, ip: fakeIPAddress, want: nil},
{port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to ip with reuseport",
actions: []portReserveTestAction{
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to inaddr any with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, release: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
},
}, {
tname: "bind twice with device fails",
@@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind with device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, want: nil},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
- tname: "bind with reuse",
+ tname: "bind with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
- tname: "binding with reuse and device",
+ tname: "binding with reuseport and device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
},
}, {
- tname: "mixing reuse and not reuse by binding to device",
+ tname: "mixing reuseport and not reuseport by binding to device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: nil},
},
}, {
- tname: "can't bind to 0 after mixing reuse and not reuse",
+ tname: "can't bind to 0 after mixing reuseport and not reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind and release",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
// Release the bind to device 0 and try again.
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
},
}, {
- tname: "bind twice with reuse once",
+ tname: "bind twice with reuseport once",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "release an unreserved device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
// The below don't exist.
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
- {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
// Release all.
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
+ },
+ }, {
+ tname: "bind with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuseaddr once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseport, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
},
},
} {
@@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index cfdd0496e..060a2e7c6 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -58,6 +58,14 @@ const (
// Default = true.
defaultDiscoverOnLinkPrefixes = true
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address AutoConfiguration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
// minimumRetransmitTimer is the minimum amount of time to wait between
// sending NDP Neighbor solicitation messages. Note, RFC 4861 does
// not impose a minimum Retransmit Timer, but we do here to make sure
@@ -87,6 +95,24 @@ const (
//
// Max = 10.
MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
+)
+
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
)
// NDPDispatcher is the interface integrators of netstack must implement to
@@ -139,6 +165,33 @@ type NDPDispatcher interface {
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route
+
+ // OnAutoGenAddress will be called when a new prefix with its
+ // autonomous address-configuration flag set has been received and SLAAC
+ // has been performed. Implementations may prevent the stack from
+ // assigning the address to the NIC by returning false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressInvalidated will be called when an auto-generated
+ // address (as part of SLAAC) has been invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption will be called when an NDP option with
+ // recursive DNS servers has been received. Note, addrs may contain
+ // link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
}
// NDPConfigurations is the NDP configurations for the netstack.
@@ -168,6 +221,17 @@ type NDPConfigurations struct {
// will be discovered from Router Advertisements' Prefix Information
// option. This configuration is ignored if HandleRAs is false.
DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not global IPv6
+ // addresses will be generated for a NIC in response to receiving a new
+ // Prefix Information option with its Autonomous Address
+ // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
@@ -179,6 +243,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
HandleRAs: defaultHandleRAs,
DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
}
}
@@ -210,6 +275,9 @@ type ndpState struct {
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The addresses generated by SLAAC.
+ autoGenAddresses map[tcpip.Address]autoGenAddressState
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -270,6 +338,32 @@ type onLinkPrefixState struct {
doNotInvalidate *bool
}
+// autoGenAddressState holds data associated with an address generated via
+// SLAAC.
+type autoGenAddressState struct {
+ invalidationTimer *time.Timer
+
+ // Used to signal the timer not to invalidate the SLAAC address (A) in
+ // a race condition (T1 is a goroutine that handles a PI for A and T2
+ // is the goroutine that handles A's invalidation timer firing):
+ // T1: Receive a new PI for A
+ // T1: Obtain the NIC's lock before processing the PI
+ // T2: A's invalidation timer fires, and gets blocked on obtaining the
+ // NIC's lock
+ // T1: Refreshes/extends A's lifetime & releases NIC's lock
+ // T2: Obtains NIC's lock & invalidates A immediately
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using doNotInvalidate to not invalidate A, so that
+ // once T2 obtains the lock, it will see that it is set to true and do
+ // nothing further.
+ doNotInvalidate *bool
+
+ // Nonzero only when the address is not valid forever (invalidationTimer
+ // is not nil).
+ validUntil time.Time
+}
+
// startDuplicateAddressDetection performs Duplicate Address Detection.
//
// This function must only be called by IPv6 addresses that are currently
@@ -534,19 +628,21 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// we do not check the iterator for errors on calls to Next.
it, _ := ra.Options().Iter(false)
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
- switch opt.Type() {
- case header.NDPPrefixInformationType:
- if !ndp.configs.DiscoverOnLinkPrefixes {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.nic.stack.ndpDisp == nil {
continue
}
- pi := opt.(header.NDPPrefixInformation)
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime())
- prefix := pi.Subnet()
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
// Is the prefix a link-local?
if header.IsV6LinkLocalAddress(prefix.ID()) {
- // ...Yes, skip as per RFC 4861 section 6.3.4.
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
continue
}
@@ -557,82 +653,13 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
continue
}
- if !pi.OnLinkFlag() {
- // Not on-link so don't "discover" it as an
- // on-link prefix.
- continue
- }
-
- prefixState, ok := ndp.onLinkPrefixes[prefix]
- vl := pi.ValidLifetime()
- switch {
- case !ok && vl == 0:
- // Don't know about this prefix but has a zero
- // valid lifetime, so just ignore.
- continue
-
- case !ok && vl != 0:
- // This is a new on-link prefix we are
- // discovering.
- //
- // Only remember it if we currently know about
- // less than MaxDiscoveredOnLinkPrefixes on-link
- // prefixes.
- if len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
- ndp.rememberOnLinkPrefix(prefix, vl)
- }
- continue
-
- case ok && vl == 0:
- // We know about the on-link prefix, but it is
- // no longer to be considered on-link, so
- // invalidate it.
- ndp.invalidateOnLinkPrefix(prefix)
- continue
- }
-
- // This is an already discovered on-link prefix with a
- // new non-zero valid lifetime.
- // Update the invalidation timer.
- timer := prefixState.invalidationTimer
-
- if timer == nil && vl >= header.NDPPrefixInformationInfiniteLifetime {
- // Had infinite valid lifetime before and
- // continues to have an invalid lifetime. Do
- // nothing further.
- continue
- }
-
- if timer != nil && !timer.Stop() {
- // If we reach this point, then we know the
- // timer already fired after we took the NIC
- // lock. Inform the timer to not invalidate
- // the prefix once it obtains the lock as we
- // just got a new PI that refeshes its lifetime
- // to a non-zero value. See
- // onLinkPrefixState.doNotInvalidate for more
- // details.
- *prefixState.doNotInvalidate = true
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
}
- if vl >= header.NDPPrefixInformationInfiniteLifetime {
- // Prefix is now valid forever so we don't need
- // an invalidation timer.
- prefixState.invalidationTimer = nil
- ndp.onLinkPrefixes[prefix] = prefixState
- continue
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
}
-
- if timer != nil {
- // We already have a timer so just reset it to
- // expire after the new valid lifetime.
- timer.Reset(vl)
- continue
- }
-
- // We do not have a timer so just create a new one.
- prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
- ndp.onLinkPrefixes[prefix] = prefixState
}
// TODO(b/141556115): Do (MTU) Parameter Discovery.
@@ -734,7 +761,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
var timer *time.Timer
// Only create a timer if the lifetime is not infinite.
- if l < header.NDPPrefixInformationInfiniteLifetime {
+ if l < header.NDPInfiniteLifetime {
timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate)
}
@@ -795,3 +822,345 @@ func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Dur
ndp.invalidateOnLinkPrefix(prefix)
})
}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ // Update the invalidation timer.
+ timer := prefixState.invalidationTimer
+
+ if timer == nil && vl >= header.NDPInfiniteLifetime {
+ // Had infinite valid lifetime before and
+ // continues to have an invalid lifetime. Do
+ // nothing further.
+ return
+ }
+
+ if timer != nil && !timer.Stop() {
+ // If we reach this point, then we know the timer alread fired
+ // after we took the NIC lock. Inform the timer to not
+ // invalidate the prefix once it obtains the lock as we just
+ // got a new PI that refreshes its lifetime to a non-zero value.
+ // See onLinkPrefixState.doNotInvalidate for more details.
+ *prefixState.doNotInvalidate = true
+ }
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Prefix is now valid forever so we don't need
+ // an invalidation timer.
+ prefixState.invalidationTimer = nil
+ ndp.onLinkPrefixes[prefix] = prefixState
+ return
+ }
+
+ if timer != nil {
+ // We already have a timer so just reset it to
+ // expire after the new valid lifetime.
+ timer.Reset(vl)
+ return
+ }
+
+ // We do not have a timer so just create a new one.
+ prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already have an auto-generated address for prefix.
+ for _, ref := range ndp.nic.endpoints {
+ if ref.protocol != header.IPv6ProtocolNumber {
+ continue
+ }
+
+ if ref.configType != slaac {
+ continue
+ }
+
+ addr := ref.ep.ID().LocalAddress
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ if refAddrWithPrefix.Subnet() != prefix {
+ continue
+ }
+
+ //
+ // At this point, we know we are refreshing a SLAAC generated
+ // IPv6 address with the prefix, prefix. Do the work as outlined
+ // by RFC 4862 section 5.5.3.e.
+ //
+
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr))
+ }
+
+ // TODO(b/143713887): Handle deprecating auto-generated address
+ // after the preferred lifetime.
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
+ // address generated by SLAAC is as follows:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or
+ // greater than RemainingLifetime, set the valid lifetime of
+ // the address to the advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours,
+ // ignore the advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2
+ // hours.
+
+ // Handle the infinite valid lifetime separately as we do not
+ // keep a timer in this case.
+ if vl >= header.NDPInfiniteLifetime {
+ if addrState.invalidationTimer != nil {
+ // Valid lifetime was finite before, but now it
+ // is valid forever.
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer = nil
+ addrState.validUntil = time.Time{}
+ ndp.autoGenAddresses[addr] = addrState
+ }
+
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever,
+ // assume the remaining time to be the maximum possible value.
+ if addrState.invalidationTimer == nil {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if addrState.invalidationTimer == nil {
+ addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
+ } else {
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer.Reset(effectiveVl)
+ }
+
+ addrState.validUntil = time.Now().Add(effectiveVl)
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ }
+
+ // We do not already have an address within the prefix, prefix. Do the
+ // work as outlined by RFC 4862 section 5.5.3.d if n is configured
+ // to auto-generated global addresses by SLAAC.
+
+ // Are we configured to auto-generate new global addresses?
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 sectiion 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
+
+ // Generate an address within prefix from the EUI-64 of ndp's NIC's
+ // Ethernet MAC address.
+ addrBytes := make([]byte, header.IPv6AddressSize)
+ copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
+ header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.Address(addrBytes)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ // If the nic already has this address, do nothing further.
+ if ndp.nic.hasPermanentAddrLocked(addr) {
+ return
+ }
+
+ // Inform the integrator that we have a new SLAAC address.
+ if ndp.nic.stack.ndpDisp == nil {
+ return
+ }
+ if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+ // Informed by the integrator not to add the address.
+ return
+ }
+
+ if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
+ panic(err)
+ }
+
+ // Setup the timers to deprecate and invalidate this newly generated
+ // address.
+
+ // TODO(b/143713887): Handle deprecating auto-generated addresses
+ // after the preferred lifetime.
+
+ var doNotInvalidate bool
+ var vTimer *time.Timer
+ if vl < header.NDPInfiniteLifetime {
+ vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ }
+
+ ndp.autoGenAddresses[addr] = autoGenAddressState{
+ invalidationTimer: vTimer,
+ doNotInvalidate: &doNotInvalidate,
+ validUntil: time.Now().Add(vl),
+ }
+}
+
+// invalidateAutoGenAddress invalidates an auto-generated address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
+ if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
+ return
+ }
+
+ ndp.nic.removePermanentAddressLocked(addr)
+}
+
+// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
+// address's resources from ndp. If the stack has an NDP dispatcher, it will
+// be notified that addr has been invalidated.
+//
+// Returns true if ndp had resources for addr to cleanup.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
+ state, ok := ndp.autoGenAddresses[addr]
+
+ if !ok {
+ return false
+ }
+
+ if state.invalidationTimer != nil {
+ state.invalidationTimer.Stop()
+ state.invalidationTimer = nil
+ *state.doNotInvalidate = true
+ }
+
+ state.doNotInvalidate = nil
+
+ delete(ndp.autoGenAddresses, addr)
+
+ if ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
+ }
+
+ return true
+}
+
+// autoGenAddrInvalidationTimer returns a new invalidation timer for an
+// auto-generated address that fires after vl.
+//
+// doNotInvalidate is used to inform the timer when it fires at the same time
+// that an auto-generated address's valid lifetime gets refreshed. See
+// autoGenAddrState.doNotInvalidate for more details.
+func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
+ return time.AfterFunc(vl, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if *doNotInvalidate {
+ *doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateAutoGenAddress(addr)
+ })
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 5b901f947..8d811eb8e 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -38,7 +38,7 @@ const (
linkAddr1 = "\x02\x02\x03\x04\x05\x06"
linkAddr2 = "\x02\x02\x03\x04\x05\x07"
linkAddr3 = "\x02\x02\x03\x04\x05\x08"
- defaultTimeout = 250 * time.Millisecond
+ defaultTimeout = 100 * time.Millisecond
)
var (
@@ -47,6 +47,31 @@ var (
llAddr3 = header.LinkLocalAddr(linkAddr3)
)
+// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
+// tcpip.Subnet, and an address where the lower half of the address is composed
+// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
+func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
+ prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixBytes),
+ PrefixLen: 64,
+ }
+
+ subnet := prefix.Subnet()
+
+ var addr tcpip.AddressWithPrefix
+ if header.IsValidUnicastEthernetAddress(linkAddr) {
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ }
+
+ return prefix, subnet, addr
+}
+
// TestDADDisabled tests that an address successfully resolves immediately
// when DAD is not enabled (the default for an empty stack.Options).
func TestDADDisabled(t *testing.T) {
@@ -103,6 +128,29 @@ type ndpPrefixEvent struct {
discovered bool
}
+type ndpAutoGenAddrEventType int
+
+const (
+ newAddr ndpAutoGenAddrEventType = iota
+ invalidatedAddr
+)
+
+type ndpAutoGenAddrEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.AddressWithPrefix
+ eventType ndpAutoGenAddrEventType
+}
+
+type ndpRDNSS struct {
+ addrs []tcpip.Address
+ lifetime time.Duration
+}
+
+type ndpRDNSSEvent struct {
+ nicID tcpip.NICID
+ rdnss ndpRDNSS
+}
+
var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
@@ -113,6 +161,8 @@ type ndpDispatcher struct {
rememberRouter bool
prefixC chan ndpPrefixEvent
rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
routeTable []tcpip.Route
}
@@ -211,7 +261,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
}
}
- rt := make([]tcpip.Route, 0)
+ var rt []tcpip.Route
exclude := tcpip.Route{
Destination: prefix,
NIC: nicID,
@@ -226,6 +276,40 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
return rt
}
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+ if n.autoGenAddrC != nil {
+ n.autoGenAddrC <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ newAddr,
+ }
+ }
+ return true
+}
+
+func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if n.autoGenAddrC != nil {
+ n.autoGenAddrC <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ invalidatedAddr,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
+ if n.rdnssC != nil {
+ n.rdnssC <- ndpRDNSSEvent{
+ nicID,
+ ndpRDNSS{
+ addrs,
+ lifetime,
+ },
+ }
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -247,6 +331,8 @@ func TestDADResolve(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
@@ -781,16 +867,33 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink bool, vl uint32) tcpip.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
flags := uint8(0)
if onLink {
- flags |= 128
+ // The OnLink flag is the 7th bit in the flags byte.
+ flags |= 1 << 7
+ }
+ if auto {
+ // The Address Auto-Configuration flag is the 6th bit in the
+ // flags byte.
+ flags |= 1 << 6
}
+ // A valid header.NDPPrefixInformation must be 30 bytes.
buf := [30]byte{}
+ // The first byte in a header.NDPPrefixInformation is the Prefix Length
+ // field.
buf[0] = uint8(prefix.PrefixLen)
+ // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
buf[1] = flags
+ // The Valid Lifetime field starts after the 2nd byte within a
+ // header.NDPPrefixInformation.
binary.BigEndian.PutUint32(buf[2:], vl)
+ // The Preferred Lifetime field starts after the 6th byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[6:], pl)
+ // The Prefix Address field starts after the 14th byte within a
+ // header.NDPPrefixInformation.
copy(buf[14:], prefix.Address)
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
header.NDPPrefixInformation(buf[:]),
@@ -800,6 +903,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
// TestNoRouterDiscovery tests that router discovery will not be performed if
// configured not to.
func TestNoRouterDiscovery(t *testing.T) {
+ t.Parallel()
+
// Being configured to discover routers means handle and
// discover are set to true and forwarding is set to false.
// This tests all possible combinations of the configurations,
@@ -812,6 +917,8 @@ func TestNoRouterDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 10),
}
@@ -844,6 +951,8 @@ func TestNoRouterDiscovery(t *testing.T) {
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 10),
}
@@ -909,6 +1018,8 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 10),
rememberRouter: true,
@@ -1040,6 +1151,8 @@ func TestRouterDiscovery(t *testing.T) {
// TestRouterDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 10),
rememberRouter: true,
@@ -1104,6 +1217,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
// configured not to.
func TestNoPrefixDiscovery(t *testing.T) {
+ t.Parallel()
+
prefix := tcpip.AddressWithPrefix{
Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
PrefixLen: 64,
@@ -1121,6 +1236,8 @@ func TestNoPrefixDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 10),
}
@@ -1140,7 +1257,7 @@ func TestNoPrefixDiscovery(t *testing.T) {
}
// Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 10))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
select {
case <-ndpDisp.prefixC:
@@ -1154,11 +1271,9 @@ func TestNoPrefixDiscovery(t *testing.T) {
// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered on-link prefix when the dispatcher asks it not to.
func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
- subnet := prefix.Subnet()
+ t.Parallel()
+
+ prefix, subnet, _ := prefixSubnetAddr(0, "")
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 10),
@@ -1189,7 +1304,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
// Rx an RA with prefix with a short lifetime.
const lifetime = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, lifetime))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0))
select {
case r := <-ndpDisp.prefixC:
if r.nicID != 1 {
@@ -1226,21 +1341,11 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestPrefixDiscovery(t *testing.T) {
- prefix1 := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
- prefix2 := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
- prefix3 := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x0a\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 72,
- }
- subnet1 := prefix1.Subnet()
- subnet2 := prefix2.Subnet()
- subnet3 := prefix3.Subnet()
+ t.Parallel()
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, "")
+ prefix2, subnet2, _ := prefixSubnetAddr(1, "")
+ prefix3, subnet3, _ := prefixSubnetAddr(2, "")
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 10),
@@ -1281,7 +1386,7 @@ func TestPrefixDiscovery(t *testing.T) {
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
@@ -1290,7 +1395,7 @@ func TestPrefixDiscovery(t *testing.T) {
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
waitForEvent(subnet1, true, defaultTimeout)
// Should have added a device route for subnet1 through the nic.
@@ -1299,7 +1404,7 @@ func TestPrefixDiscovery(t *testing.T) {
}
// Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
waitForEvent(subnet2, true, defaultTimeout)
// Should have added a device route for subnet2 through the nic.
@@ -1308,7 +1413,7 @@ func TestPrefixDiscovery(t *testing.T) {
}
// Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
waitForEvent(subnet3, true, defaultTimeout)
// Should have added a device route for subnet3 through the nic.
@@ -1317,7 +1422,7 @@ func TestPrefixDiscovery(t *testing.T) {
}
// Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
waitForEvent(subnet1, false, defaultTimeout)
// Should have removed the device route for subnet1 through the nic.
@@ -1327,7 +1432,7 @@ func TestPrefixDiscovery(t *testing.T) {
// Receive an RA with prefix2 in a PI with lesser lifetime.
lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, lifetime))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly received prefix event when updating lifetime")
@@ -1349,7 +1454,7 @@ func TestPrefixDiscovery(t *testing.T) {
}
// Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
waitForEvent(subnet3, false, defaultTimeout)
// Should not have any routes.
@@ -1364,10 +1469,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// invalidate the prefix.
const testInfiniteLifetimeSeconds = 2
const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
- saved := header.NDPPrefixInformationInfiniteLifetime
- header.NDPPrefixInformationInfiniteLifetime = testInfiniteLifetime
+ saved := header.NDPInfiniteLifetime
+ header.NDPInfiniteLifetime = testInfiniteLifetime
defer func() {
- header.NDPPrefixInformationInfiniteLifetime = saved
+ header.NDPInfiniteLifetime = saved
}()
prefix := tcpip.AddressWithPrefix{
@@ -1415,7 +1520,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with prefix in an NDP Prefix Information option (PI)
// with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
waitForEvent(true, defaultTimeout)
select {
case <-ndpDisp.prefixC:
@@ -1425,16 +1530,16 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with finite lifetime.
// The prefix should get invalidated after 1s.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
waitForEvent(false, testInfiniteLifetime)
// Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
waitForEvent(true, defaultTimeout)
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1443,7 +1548,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with a prefix with a lifetime value greater than the
// set infinite lifetime value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds+1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1452,13 +1557,15 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with 0 lifetime.
// The prefix should get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
waitForEvent(false, defaultTimeout)
}
// TestPrefixDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
@@ -1537,3 +1644,606 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt)
}
}
+
+// Checks to see if list contains an IPv6 address, item.
+func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: item,
+ }
+
+ for _, i := range list {
+ if i == protocolAddress {
+ return true
+ }
+ }
+
+ return false
+}
+
+// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
+func TestNoAutoGenAddr(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, _ := prefixSubnetAddr(0, "")
+
+ // Being configured to auto-generate addresses means handle and
+ // autogen are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, autogen =
+ // true and forwarding = false (the required configuration to do
+ // SLAAC) - that will done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ autogen := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
+
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when configured not to")
+ case <-time.After(defaultTimeout):
+ }
+ })
+ }
+}
+
+// TestAutoGenAddr tests that an address is properly generated and invalidated
+// when configured to do so.
+func TestAutoGenAddr(t *testing.T) {
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case r := <-ndpDisp.autoGenAddrC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.eventType != eventType {
+ t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ }
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ case <-time.After(defaultTimeout):
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ waitForEvent(addr1, newAddr, defaultTimeout)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ case <-time.After(defaultTimeout):
+ }
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ waitForEvent(addr2, newAddr, defaultTimeout)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ case <-time.After(defaultTimeout):
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout)
+ if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+}
+
+// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
+// auto-generated address only gets updated when required to, as specified in
+// RFC 4862 section 5.5.3.e.
+func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
+ const infiniteVL = 4294967295
+ const newMinVL = 5
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ ovl uint32
+ nvl uint32
+ evl uint32
+ }{
+ // Should update the VL to the minimum VL for updating if the
+ // new VL is less than newMinVL but was originally greater than
+ // it.
+ {
+ "LargeVLToVLLessThanMinVLForUpdate",
+ 9999,
+ 1,
+ newMinVL,
+ },
+ {
+ "LargeVLTo0",
+ 9999,
+ 0,
+ newMinVL,
+ },
+ {
+ "InfiniteVLToVLLessThanMinVLForUpdate",
+ infiniteVL,
+ 1,
+ newMinVL,
+ },
+ {
+ "InfiniteVLTo0",
+ infiniteVL,
+ 0,
+ newMinVL,
+ },
+
+ // Should not update VL if original VL was less than newMinVL
+ // and the new VL is also less than newMinVL.
+ {
+ "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
+ newMinVL - 1,
+ newMinVL - 3,
+ newMinVL - 1,
+ },
+
+ // Should take the new VL if the new VL is greater than the
+ // remaining time or is greater than newMinVL.
+ {
+ "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
+ newMinVL + 5,
+ newMinVL + 3,
+ newMinVL + 3,
+ },
+ {
+ "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL - 1,
+ newMinVL - 1,
+ },
+ {
+ "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL + 1,
+ newMinVL + 1,
+ },
+ }
+
+ const delta = 500 * time.Millisecond
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix with initial VL, test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case r := <-ndpDisp.autoGenAddrC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.eventType != newAddr {
+ t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+
+ // Receive an new RA with prefix with new VL, test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+
+ //
+ // Validate that the VL for the address got set to
+ // test.evl.
+ //
+
+ // Make sure we do not get any invalidation events
+ // until atleast 500ms (delta) before test.evl.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - delta):
+ }
+
+ // Wait for another second (2x delta), but now we expect
+ // the invalidation event.
+ select {
+ case r := <-ndpDisp.autoGenAddrC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.eventType != invalidatedAddr {
+ t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ }
+ case <-time.After(2 * delta):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
+// by the user, its resources will be cleaned up and an invalidation event will
+// be sent to the integrator.
+func TestAutoGenAddrRemoval(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix with its valid lifetime = lifetime.
+ const lifetime = 5
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0))
+ select {
+ case r := <-ndpDisp.autoGenAddrC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.eventType != newAddr {
+ t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+
+ // Remove the address.
+ if err := s.RemoveAddress(1, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
+ }
+
+ // Should get the invalidation event immediately.
+ select {
+ case r := <-ndpDisp.autoGenAddrC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.eventType != invalidatedAddr {
+ t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr)
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+
+ // Wait for the original valid lifetime to make sure the original timer
+ // got stopped/cleaned up.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetime*time.Second + defaultTimeout):
+ }
+}
+
+// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
+// is already assigned to the NIC, the static address remains.
+func TestAutoGenAddrStaticConflict(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Add the address as a static address before SLAAC tries to add it.
+ if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive a PI where the generated address will be the same as the one
+ // that we already have assigned statically.
+ const lifetime = 5
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
+ case <-time.After(defaultTimeout):
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Should not get an invalidation event after the PI's invalidation
+ // time.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetime*time.Second + defaultTimeout):
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+}
+
+// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
+// to the integrator when an RA is received with the NDP Recursive DNS Server
+// option with at least one valid address.
+func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ opt header.NDPRecursiveDNSServer
+ expected *ndpRDNSS
+ }{
+ {
+ "Unspecified",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }),
+ nil,
+ },
+ {
+ "Multicast",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ nil,
+ },
+ {
+ "OptionTooSmall",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ }),
+ nil,
+ },
+ {
+ "0Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ }),
+ nil,
+ },
+ {
+ "Valid1Address",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ },
+ 2 * time.Second,
+ },
+ },
+ {
+ "Valid2Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ },
+ time.Second,
+ },
+ },
+ {
+ "Valid3Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
+ },
+ 0,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ // We do not expect more than a single RDNSS
+ // event at any time for this test.
+ rdnssC: make(chan ndpRDNSSEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
+
+ if test.expected != nil {
+ select {
+ case e := <-ndpDisp.rdnssC:
+ if e.nicID != 1 {
+ t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
+ }
+ if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
+ t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
+ }
+ if e.rdnss.lifetime != test.expected.lifetime {
+ t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
+ }
+ default:
+ t.Fatal("expected an RDNSS option event")
+ }
+ }
+
+ // Should have no more RDNSS options.
+ select {
+ case e := <-ndpDisp.rdnssC:
+ t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
+ default:
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3f8d7312c..e8401c673 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -115,10 +115,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
},
},
ndp: ndpState{
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
},
}
nic.ndp.nic = nic
@@ -244,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+// hasPermanentAddrLocked returns true if n has a permanent (including currently
+// tentative) address, addr.
+func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+
+ if !ok {
+ return false
+ }
+
+ kind := ref.getKind()
+
+ return kind == permanent || kind == permanentTentative
+}
+
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
}
@@ -335,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary)
+ }, peb, temporary, static)
n.mu.Unlock()
return ref
@@ -384,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent)
+ return n.addAddressLocked(protocolAddress, peb, permanent, static)
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
// Sanity check.
@@ -417,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
+ configType: configType,
}
// Set up cache if link address resolution exists for this protocol.
@@ -624,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
- // If we are removing a tentative IPv6 unicast address, stop DAD.
- if isIPv6Unicast && kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ if isIPv6Unicast {
+ // If we are removing a tentative IPv6 unicast address, stop
+ // DAD.
+ if kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ if r.configType == slaac {
+ n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ }
}
r.setKind(permanentExpired)
@@ -989,7 +1014,7 @@ const (
// removing the permanent address from the NIC.
permanent
- // An expired permanent endoint is a permanent endoint that had its address
+ // An expired permanent endpoint is a permanent endpoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
// hold a reference to it. This is achieved by decreasing its reference count
// by 1. If its address is re-added before the endpoint is removed, its type
@@ -1035,6 +1060,19 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
}
+type networkEndpointConfigType int32
+
+const (
+ // A statically configured endpoint is an address that was added by
+ // some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ static networkEndpointConfigType = iota
+
+ // A slaac configured endpoint is an IPv6 endpoint that was
+ // added by SLAAC as per RFC 4862 section 5.5.3.
+ slaac
+)
+
type referencedNetworkEndpoint struct {
ep NetworkEndpoint
nic *NIC
@@ -1050,6 +1088,10 @@ type referencedNetworkEndpoint struct {
// networkEndpointKind must only be accessed using {get,set}Kind().
kind networkEndpointKind
+
+ // configType is the method that was used to configure this endpoint.
+ // This must never change after the endpoint is added to a NIC.
+ configType networkEndpointConfigType
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 5746043cc..d5bb5b6ed 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -924,6 +924,14 @@ type TCPStats struct {
// ESTABLISHED state or the CLOSE-WAIT state.
EstablishedResets *StatCounter
+ // EstablishedClosed is the number of times established TCP connections
+ // made a transition to CLOSED state.
+ EstablishedClosed *StatCounter
+
+ // EstablishedTimedout is the number of times an established connection
+ // was reset because of keep-alive time out.
+ EstablishedTimedout *StatCounter
+
// ListenOverflowSynDrop is the number of times the listen queue overflowed
// and a SYN was dropped.
ListenOverflowSynDrop *StatCounter
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index dd1728f9c..455a1c098 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -52,6 +52,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index f543a6105..74df3edfb 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -298,8 +298,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
ep.mu.Lock()
- ep.stack.Stats().TCP.CurrentEstablished.Increment()
- ep.state = StateEstablished
ep.isConnectNotified = true
ep.mu.Unlock()
@@ -546,6 +544,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
+ // We do not use transitionToStateEstablishedLocked here as there is
+ // no handshake state available when doing a SYN cookie based accept.
n.stack.Stats().TCP.CurrentEstablished.Increment()
n.state = StateEstablished
n.isConnectNotified = true
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4206db8b6..3d059c302 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -252,6 +252,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// and the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
+
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
}
@@ -352,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
return nil
}
@@ -880,6 +889,30 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+// transitionToStateEstablisedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver if required.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ if e.snd == nil {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+ }
+ if e.rcv == nil {
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+ }
+ h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.state = StateEstablished
+}
+
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
@@ -891,6 +924,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
}
e.cleanupLocked()
e.state = StateClose
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
@@ -953,20 +987,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments() *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- e.mu.RLock()
- state := e.state
- e.mu.RUnlock()
- if state == StateClose {
- // When we get into StateClose while processing from the queue,
- // return immediately and let the protocolMainloop handle it.
- //
- // We can reach StateClose only while processing a previous segment
- // or a notification from the protocolMainLoop (caller goroutine).
- // This means that with this return, the segment dequeue below can
- // never occur on a closed endpoint.
- return nil
- }
-
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
@@ -1024,6 +1044,24 @@ func (e *endpoint) handleSegments() *tcpip.Error {
s.decRef()
continue
}
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return nil
+ }
e.snd.handleRcvdSegment(s)
}
s.decRef()
@@ -1057,6 +1095,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
@@ -1142,8 +1181,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
e.state = StateError
e.HardError = err
@@ -1152,25 +1189,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
-
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // boot strap the auto tuning algorithm. Starting at zero will
- // result in a large step function on the first proper causing
- // the window to just go to a really large value after the first
- // RTT itself.
- e.rcvAutoParams.prevCopied = initialRcvWnd
- e.rcvListMu.Unlock()
- e.stack.Stats().TCP.CurrentEstablished.Increment()
- e.mu.Lock()
- e.state = StateEstablished
- e.mu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
@@ -1371,7 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Mark endpoint as closed.
e.mu.Lock()
if e.state != StateError {
- e.stack.Stats().TCP.EstablishedResets.Increment()
e.stack.Stats().TCP.CurrentEstablished.Decrement()
e.transitionToStateCloseLocked()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9d4a87e30..4861ab513 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -343,6 +344,7 @@ type endpoint struct {
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -737,9 +739,10 @@ func (e *endpoint) Close() {
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
// Mark endpoint as closed.
@@ -800,10 +803,11 @@ func (e *endpoint) cleanupLocked() {
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
@@ -1775,7 +1779,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// reusePort is false below because connect cannot reuse a port even if
// reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
return false, nil
}
@@ -1802,7 +1806,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
@@ -2034,28 +2038,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
if err != nil {
return err
}
e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = flags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func(bindToDevice tcpip.NICID) {
+ defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
e.ID.LocalAddress = ""
e.boundNICID = 0
e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
- }(e.boundBindToDevice)
+ }(e.boundPortFlags, e.boundBindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 857dc445f..5ee499c36 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -205,7 +205,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
switch r.ep.state {
case StateFinWait1:
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index d3f7c9125..8332a0179 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -674,7 +674,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
default:
s.ep.state = StateFinWait1
}
- s.ep.stack.Stats().TCP.CurrentEstablished.Decrement()
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 50829ae27..bc5cfcf0e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) {
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
}
+
+ // Call Connect again to retreive the handshake failure status
+ // and stats updates.
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ }
+
+ if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestConnectIncrementActiveConnection(t *testing.T) {
@@ -541,13 +555,21 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
ep.(interface{ ResumeWork() }).ResumeWork()
// Wait for the protocolMainLoop to resume and update state.
- time.Sleep(1 * time.Millisecond)
+ time.Sleep(10 * time.Millisecond)
// Expect the endpoint to be closed.
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+
// Check if the endpoint was moved to CLOSED and netstack a reset in
// response to the ACK packet that we sent after last-ACK.
checker.IPv4(t, c.GetPacket(),
@@ -2694,6 +2716,13 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
+
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -4363,9 +4392,17 @@ func TestKeepalive(t *testing.T) {
),
)
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ }
+
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
@@ -5632,6 +5669,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@@ -5750,6 +5788,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@@ -5856,6 +5895,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@@ -5929,6 +5969,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
@@ -5941,6 +5982,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@@ -5987,6 +6029,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
}
+ want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
@@ -6007,6 +6051,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
+ RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
@@ -6114,4 +6159,184 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.SeqNum(uint32(ackHeaders.AckNum)),
checker.AckNum(uint32(ackHeaders.SeqNum)),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+}
+
+func TestTCPCloseWithData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: 30000,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now trigger a passive close by sending a FIN.
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ RcvWnd: 30000,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now write a few bytes and then close the endpoint.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+
+ c.EP.Close()
+ // Check the FIN.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.AckNum(uint32(iss+2)),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ // First send a partial ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send a full ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now ACK the FIN.
+ ackHeaders.AckNum++
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send an ACK and we should get a RST back as the endpoint should
+ // be in CLOSED state.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Check the RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(uint32(ackHeaders.SeqNum)),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6cb66c1af..b0a376eba 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -231,6 +231,7 @@ func (c *Context) CheckNoPacket(errMsg string) {
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
@@ -259,6 +260,7 @@ func (c *Context) GetPacket() []byte {
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
@@ -483,6 +485,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv6.ProtocolNumber {
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 8d4c3808f..97e4d5825 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 24cb88c13..1ac4705af 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -107,6 +108,7 @@ type endpoint struct {
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
// applied while sending packets. Defaults to 0 as on Linux.
@@ -180,8 +182,9 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
for _, mem := range e.multicastMemberships {
@@ -895,7 +898,8 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
@@ -1042,16 +1046,23 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ // FIXME(b/129164367): Support SO_REUSEADDR.
+ MostRecent: false,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
if err != nil {
return id, e.bindToDevice, err
}
+ e.boundPortFlags = flags
id.LocalPort = port
}
err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
return id, e.bindToDevice, err
}
@@ -1134,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
+ addr := e.ID.LocalAddress
+ if e.state == StateConnected {
+ addr = e.route.LocalAddress
+ }
+
return tcpip.FullAddress{
NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
+ Addr: addr,
Port: e.ID.LocalPort,
}, nil
}