diff options
Diffstat (limited to 'pkg')
140 files changed, 7212 insertions, 941 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a89f34d4b..322d1ccc4 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -30,6 +30,7 @@ go_library( "futex.go", "inotify.go", "ioctl.go", + "ioctl_tun.go", "ip.go", "ipc.go", "limits.go", diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go index 6e4de69da..1121a1a92 100644 --- a/pkg/abi/linux/epoll.go +++ b/pkg/abi/linux/epoll.go @@ -14,6 +14,10 @@ package linux +import ( + "gvisor.dev/gvisor/pkg/binary" +) + // Event masks. const ( EPOLLIN = 0x1 @@ -53,3 +57,6 @@ const ( EPOLL_CTL_DEL = 0x2 EPOLL_CTL_MOD = 0x3 ) + +// SizeOfEpollEvent is the size of EpollEvent struct. +var SizeOfEpollEvent = int(binary.Size(EpollEvent{})) diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go index 57041491c..34ff18009 100644 --- a/pkg/abi/linux/epoll_amd64.go +++ b/pkg/abi/linux/epoll_amd64.go @@ -15,6 +15,8 @@ package linux // EpollEvent is equivalent to struct epoll_event from epoll(2). +// +// +marshal type EpollEvent struct { Events uint32 // Linux makes struct epoll_event::data a __u64. We represent it as diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go index 62ef5821e..f86c35329 100644 --- a/pkg/abi/linux/epoll_arm64.go +++ b/pkg/abi/linux/epoll_arm64.go @@ -15,6 +15,8 @@ package linux // EpollEvent is equivalent to struct epoll_event from epoll(2). +// +// +marshal type EpollEvent struct { Events uint32 // Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index c3ab15a4f..e229ac21c 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -241,6 +241,8 @@ const ( ) // Statx represents struct statx. +// +// +marshal type Statx struct { Mask uint32 Blksize uint32 diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 2c652baa2..158d2db5b 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -38,6 +38,8 @@ const ( ) // Statfs is struct statfs, from uapi/asm-generic/statfs.h. +// +// +marshal type Statfs struct { // Type is one of the filesystem magic values, defined above. Type uint64 diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 0e18db9ef..2062e6a4b 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -72,3 +72,29 @@ const ( SIOCGMIIPHY = 0x8947 SIOCGMIIREG = 0x8948 ) + +// ioctl(2) directions. Used to calculate requests number. +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NONE = 0 + _IOC_WRITE = 1 + _IOC_READ = 2 +) + +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NRBITS = 8 + _IOC_TYPEBITS = 8 + _IOC_SIZEBITS = 14 + _IOC_DIRBITS = 2 + + _IOC_NRSHIFT = 0 + _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS + _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS + _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS +) + +// IOC outputs the result of _IOC macro in asm-generic/ioctl.h. +func IOC(dir, typ, nr, size uint32) uint32 { + return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT +} diff --git a/pkg/fspath/builder_unsafe.go b/pkg/abi/linux/ioctl_tun.go index 75606808d..c59c9c136 100644 --- a/pkg/fspath/builder_unsafe.go +++ b/pkg/abi/linux/ioctl_tun.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package fspath +package linux -import ( - "unsafe" +// ioctl(2) request numbers from linux/if_tun.h +var ( + TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4) + TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4) ) -// String returns the accumulated string. No other methods should be called -// after String. -func (b *Builder) String() string { - bs := b.buf[b.start:] - // Compare strings.Builder.String(). - return *(*string)(unsafe.Pointer(&bs)) -} +// Flags from net/if_tun.h +const ( + IFF_TUN = 0x0001 + IFF_TAP = 0x0002 + IFF_NO_PI = 0x1000 + IFF_NOFILTER = 0x1000 +) diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index c69b04ea9..1c330e763 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -115,6 +115,8 @@ const ( ) // SignalSet is a signal mask with a bit corresponding to each signal. +// +// +marshal type SignalSet uint64 // SignalSetSize is the size in bytes of a SignalSet. diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go index e562b46d9..e6860ed49 100644 --- a/pkg/abi/linux/time.go +++ b/pkg/abi/linux/time.go @@ -157,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec { const SizeOfTimeval = 16 // Timeval represents struct timeval in <time.h>. +// +// +marshal type Timeval struct { Sec int64 Usec int64 @@ -230,6 +232,8 @@ type Tms struct { type TimerID int32 // StatxTimestamp represents struct statx_timestamp. +// +// +marshal type StatxTimestamp struct { Sec int64 Nsec uint32 @@ -258,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) { } // Utime represents struct utimbuf used by utimes(2). +// +// +marshal type Utime struct { Actime int64 Modtime int64 diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index a3b6406fa..99180b208 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -18,6 +18,7 @@ package linux const ( XATTR_NAME_MAX = 255 XATTR_SIZE_MAX = 65536 + XATTR_LIST_MAX = 65536 XATTR_CREATE = 1 XATTR_REPLACE = 2 diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index ba8b06071..1a30f6967 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "atomicbitops.go", "atomicbitops_amd64.s", + "atomicbitops_arm64.s", "atomicbitops_noasm.go", ], visibility = ["//:sandbox"], diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index ee84471b2..67dd1e225 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -8,9 +8,11 @@ go_library( name = "fspath", srcs = [ "builder.go", - "builder_unsafe.go", "fspath.go", ], + deps = [ + "//pkg/gohacks", + ], ) go_test( diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go index 7ddb36826..6318d3874 100644 --- a/pkg/fspath/builder.go +++ b/pkg/fspath/builder.go @@ -16,6 +16,8 @@ package fspath import ( "fmt" + + "gvisor.dev/gvisor/pkg/gohacks" ) // Builder is similar to strings.Builder, but is used to produce pathnames @@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) { copy(b.buf[b.start:], b.buf[oldStart:]) copy(b.buf[len(b.buf)-len(str):], str) } + +// String returns the accumulated string. No other methods should be called +// after String. +func (b *Builder) String() string { + return gohacks.StringFromImmutableBytes(b.buf[b.start:]) +} diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go index 9fb3fee24..4c983d5fd 100644 --- a/pkg/fspath/fspath.go +++ b/pkg/fspath/fspath.go @@ -67,7 +67,8 @@ func Parse(pathname string) Path { // Path contains the information contained in a pathname string. // -// Path is copyable by value. +// Path is copyable by value. The zero value for Path is equivalent to +// fspath.Parse(""), i.e. the empty path. type Path struct { // Begin is an iterator to the first path component in the relative part of // the path. diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD new file mode 100644 index 000000000..798a65eca --- /dev/null +++ b/pkg/gohacks/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gohacks", + srcs = [ + "gohacks_unsafe.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go new file mode 100644 index 000000000..aad675172 --- /dev/null +++ b/pkg/gohacks/gohacks_unsafe.go @@ -0,0 +1,57 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gohacks contains utilities for subverting the Go compiler. +package gohacks + +import ( + "reflect" + "unsafe" +) + +// Noescape hides a pointer from escape analysis. Noescape is the identity +// function but escape analysis doesn't think the output depends on the input. +// Noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// +// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().) +// +//go:nosplit +func Noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the +// same memory backing s instead of making a heap-allocated copy. This is only +// valid if the returned slice is never mutated. +func ImmutableBytesFromString(s string) []byte { + shdr := (*reflect.StringHeader)(unsafe.Pointer(&s)) + var bs []byte + bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + bshdr.Data = shdr.Data + bshdr.Len = shdr.Len + bshdr.Cap = shdr.Len + return bs +} + +// StringFromImmutableBytes is equivalent to string(bs), except that it uses +// the same memory backing bs instead of making a heap-allocated copy. This is +// only valid if bs is never mutated after StringFromImmutableBytes returns. +func StringFromImmutableBytes(bs []byte) string { + // This is cheaper than messing with reflect.StringHeader and + // reflect.SliceHeader, which as of this writing produces many dead stores + // of zeroes. Compare strings.Builder.String(). + return *(*string)(unsafe.Pointer(&bs)) +} diff --git a/pkg/log/glog.go b/pkg/log/glog.go index cab5fae55..b4f7bb5a4 100644 --- a/pkg/log/glog.go +++ b/pkg/log/glog.go @@ -46,7 +46,7 @@ var pid = os.Getpid() // line The line number // msg The user-supplied message // -func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) { +func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { // Log level. prefix := byte('?') switch level { @@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar microsecond := int(timestamp.Nanosecond() / 1000) // 0 = this frame. - // 1 = Debugf, etc. - // 2 = Caller. - _, file, line, ok := runtime.Caller(2) + _, file, line, ok := runtime.Caller(depth + 1) if ok { // Trim any directory path from the file. slash := strings.LastIndexByte(file, byte('/')) diff --git a/pkg/log/json.go b/pkg/log/json.go index a278c8fc8..0943db1cc 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -62,7 +62,7 @@ type JSONEmitter struct { } // Emit implements Emitter.Emit. -func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := jsonLog{ Msg: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go index cee6eb514..6c6fc8b6f 100644 --- a/pkg/log/json_k8s.go +++ b/pkg/log/json_k8s.go @@ -33,7 +33,7 @@ type K8sJSONEmitter struct { } // Emit implements Emitter.Emit. -func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := k8sJSONLog{ Log: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/log.go b/pkg/log/log.go index 5056f17e6..a794da1aa 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -79,7 +79,7 @@ func (l Level) String() string { type Emitter interface { // Emit emits the given log statement. This allows for control over the // timestamp used for logging. - Emit(level Level, timestamp time.Time, format string, v ...interface{}) + Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) } // Writer writes the output to the given writer. @@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) { } // Emit emits the message. -func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) { +func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) { fmt.Fprintf(l, format, args...) } @@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i type MultiEmitter []Emitter // Emit emits to all emitters. -func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) { for _, e := range *m { - e.Emit(level, timestamp, format, v...) + e.Emit(1+depth, level, timestamp, format, v...) } } @@ -167,7 +167,7 @@ type TestEmitter struct { } // Emit emits to the TestLogger. -func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) { +func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { t.Logf(format, v...) } @@ -198,22 +198,37 @@ type BasicLogger struct { // Debugf implements logger.Debugf. func (l *BasicLogger) Debugf(format string, v ...interface{}) { - if l.IsLogging(Debug) { - l.Emit(Debug, time.Now(), format, v...) - } + l.DebugfAtDepth(1, format, v...) } // Infof implements logger.Infof. func (l *BasicLogger) Infof(format string, v ...interface{}) { - if l.IsLogging(Info) { - l.Emit(Info, time.Now(), format, v...) - } + l.InfofAtDepth(1, format, v...) } // Warningf implements logger.Warningf. func (l *BasicLogger) Warningf(format string, v ...interface{}) { + l.WarningfAtDepth(1, format, v...) +} + +// DebugfAtDepth logs at a specific depth. +func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) { + if l.IsLogging(Debug) { + l.Emit(1+depth, Debug, time.Now(), format, v...) + } +} + +// InfofAtDepth logs at a specific depth. +func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) { + if l.IsLogging(Info) { + l.Emit(1+depth, Info, time.Now(), format, v...) + } +} + +// WarningfAtDepth logs at a specific depth. +func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) { if l.IsLogging(Warning) { - l.Emit(Warning, time.Now(), format, v...) + l.Emit(1+depth, Warning, time.Now(), format, v...) } } @@ -257,17 +272,32 @@ func SetLevel(newLevel Level) { // Debugf logs to the global logger. func Debugf(format string, v ...interface{}) { - Log().Debugf(format, v...) + Log().DebugfAtDepth(1, format, v...) } // Infof logs to the global logger. func Infof(format string, v ...interface{}) { - Log().Infof(format, v...) + Log().InfofAtDepth(1, format, v...) } // Warningf logs to the global logger. func Warningf(format string, v ...interface{}) { - Log().Warningf(format, v...) + Log().WarningfAtDepth(1, format, v...) +} + +// DebugfAtDepth logs to the global logger. +func DebugfAtDepth(depth int, format string, v ...interface{}) { + Log().DebugfAtDepth(1+depth, format, v...) +} + +// InfofAtDepth logs to the global logger. +func InfofAtDepth(depth int, format string, v ...interface{}) { + Log().InfofAtDepth(1+depth, format, v...) +} + +// WarningfAtDepth logs to the global logger. +func WarningfAtDepth(depth int, format string, v ...interface{}) { + Log().WarningfAtDepth(1+depth, format, v...) } // defaultStackSize is the default buffer size to allocate for stack traces. diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 3b6987665..d7794db63 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -34,6 +34,9 @@ const ( SyscallWidth = 4 ) +// ARMTrapFlag is the mask for the trap flag. +const ARMTrapFlag = uint64(1) << 21 + // aarch64FPState is aarch64 floating point state. type aarch64FPState []byte diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 4c4b7d5cc..9b6bb26d0 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -9,6 +9,7 @@ go_library( "device.go", "fs.go", "full.go", + "net_tun.go", "null.go", "random.go", "tty.go", @@ -19,15 +20,19 @@ go_library( "//pkg/context", "//pkg/rand", "//pkg/safemem", + "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", + "//pkg/sentry/socket/netstack", "//pkg/syserror", + "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 35bd23991..7e66c29b0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo }) } -func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)) +func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode { + iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), @@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for // the devpts to be mounted over. - "pts": newDirectory(ctx, msrc), + "pts": newDirectory(ctx, nil, msrc), // Similarly, applications expect a ptmx device at /dev/ptmx // connected to the terminals provided by /dev/pts/. Rather // than creating a device directly (which requires a hairy @@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + + "net": newDirectory(ctx, map[string]*fs.Inode{ + "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), + }, msrc), } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go new file mode 100644 index 000000000..755644488 --- /dev/null +++ b/pkg/sentry/fs/dev/net_tun.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// +stateify savable +type netTunInodeOperations struct { + fsutil.InodeGenericChecker `state:"nosave"` + fsutil.InodeNoExtendedAttributes `state:"nosave"` + fsutil.InodeNoopAllocate `state:"nosave"` + fsutil.InodeNoopRelease `state:"nosave"` + fsutil.InodeNoopTruncate `state:"nosave"` + fsutil.InodeNoopWriteOut `state:"nosave"` + fsutil.InodeNotDirectory `state:"nosave"` + fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotSocket `state:"nosave"` + fsutil.InodeNotSymlink `state:"nosave"` + fsutil.InodeVirtual `state:"nosave"` + + fsutil.InodeSimpleAttributes +} + +var _ fs.InodeOperations = (*netTunInodeOperations)(nil) + +func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations { + return &netTunInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC), + } +} + +// GetFile implements fs.InodeOperations.GetFile. +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil +} + +// +stateify savable +type netTunFileOperations struct { + fsutil.FileNoSeek `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + device tun.Device +} + +var _ fs.FileOperations = (*netTunFileOperations)(nil) + +// Release implements fs.FileOperations.Release. +func (fops *netTunFileOperations) Release() { + fops.device.Release() +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fops.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fops.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Write implements fs.FileOperations.Write. +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fops.device.Write(data) +} + +// Read implements fs.FileOperations.Read. +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// Readiness implements watier.Waitable.Readiness. +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) +} diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index acab0411a..e0b32e1c1 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1438,8 +1438,8 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName }, nil } -func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error { - uattr, err := dir.Inode.UnstableAttr(ctx) +func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error { + uattr, err := d.Inode.UnstableAttr(ctx) if err != nil { return syserror.EPERM } @@ -1465,30 +1465,33 @@ func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error { return syserror.EPERM } -// MayDelete determines whether `name`, a child of `dir`, can be deleted or +// MayDelete determines whether `name`, a child of `d`, can be deleted or // renamed by `ctx`. // // Compare Linux kernel fs/namei.c:may_delete. -func MayDelete(ctx context.Context, root, dir *Dirent, name string) error { - if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil { +func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error { + if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil { return err } - victim, err := dir.Walk(ctx, root, name) + unlock := d.lockDirectory() + defer unlock() + + victim, err := d.walk(ctx, root, name, true /* may unlock */) if err != nil { return err } defer victim.DecRef() - return mayDelete(ctx, dir, victim) + return d.mayDelete(ctx, victim) } // mayDelete determines whether `victim`, a child of `dir`, can be deleted or // renamed by `ctx`. // // Preconditions: `dir` is writable and executable by `ctx`. -func mayDelete(ctx context.Context, dir, victim *Dirent) error { - if err := checkSticky(ctx, dir, victim); err != nil { +func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error { + if err := d.checkSticky(ctx, victim); err != nil { return err } @@ -1542,7 +1545,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string defer renamed.DecRef() // Check that the renamed dirent is deletable. - if err := mayDelete(ctx, oldParent, renamed); err != nil { + if err := oldParent.mayDelete(ctx, renamed); err != nil { return err } @@ -1580,7 +1583,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // across the Rename, so must call DecRef manually (no defer). // Check that we can delete replaced. - if err := mayDelete(ctx, newParent, replaced); err != nil { + if err := newParent.mayDelete(ctx, replaced); err != nil { replaced.DecRef() return err } diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go index e672a438c..a3d10770b 100644 --- a/pkg/sentry/fs/mount_test.go +++ b/pkg/sentry/fs/mount_test.go @@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { gotPaths := make(map[string]struct{}, len(got)) gotStr := make([]string, len(got)) for i, g := range got { - groot := g.Root() - name, _ := groot.FullName(root) - groot.DecRef() - gotStr[i] = name - gotPaths[name] = struct{}{} + if groot := g.Root(); groot != nil { + name, _ := groot.FullName(root) + groot.DecRef() + gotStr[i] = name + gotPaths[name] = struct{}{} + } } if len(got) != len(want) { return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 574a2cc91..c7981f66e 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount { } } -// Root returns the root dirent of this mount. Callers must call DecRef on the -// returned dirent. +// Root returns the root dirent of this mount. +// +// This may return nil if the mount has already been free. Callers must handle this +// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent. func (m *Mount) Root() *Dirent { - m.root.IncRef() + if !m.root.TryIncRef() { + return nil + } return m.root } diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index c10888100..94deb553b 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { }) for _, m := range ms { mroot := m.Root() + if mroot == nil { + continue // No longer valid. + } mountPath, desc := mroot.FullName(rootDir) mroot.DecRef() if !desc { // MountSources that are not descendants of the chroot jail are ignored. continue } - fn(mountPath, m) } } @@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se var buf bytes.Buffer forEachMount(mif.t, func(mountPath string, m *fs.Mount) { + mroot := m.Root() + if mroot == nil { + return // No longer valid. + } + defer mroot.DecRef() + // Format: // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) @@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se // (3) Major:Minor device ID. We don't have a superblock, so we // just use the root inode device number. - mroot := m.Root() - defer mroot.DecRef() - sa := mroot.Inode.StableAttr fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor) @@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan // // The "needs dump"and fsck flags are always 0, which is allowed. root := m.Root() + if root == nil { + return // No longer valid. + } defer root.DecRef() flags := root.Inode.MountSource.Flags diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 6f2775344..95d5817ff 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -43,7 +43,10 @@ import ( // newNet creates a new proc net entry. func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 0772d4ae4..d4c4b533d 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "ipv4": p.newSysNetIPv4Dir(ctx, msrc, s), "core": p.newSysNetCore(ctx, msrc, s), diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index e657c39bc..6aa17bfc1 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -117,15 +117,19 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // default anyways. // // TODO(gvisor.dev/issue/1623): Check mount has read and exec permission. -func (l *vfsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { +func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) + path := fspath.Parse(pathname) pop := &vfs.PathOperation{ Root: l.root, - Start: l.root, - Path: fspath.Parse(path), + Start: l.workingDir, + Path: path, FollowFinalSymlink: resolveFinal, } + if path.Absolute { + pop.Start = l.root + } fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts) if err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index ce08a7d53..10c08fa90 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), "net": newNetDir(root, inoGen, k), - "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}), + "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}), "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), - "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}), + "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}), } inode := &tasksInode{ diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go index 608fec017..d4e1812d8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_net.go +++ b/pkg/sentry/fsimpl/proc/tasks_net.go @@ -39,7 +39,10 @@ import ( func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { const ( arp = "IP address HW type Flags HW address Mask Device\n" netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n" diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index c7ce74883..3d5dc463c 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ "tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}), diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index d0be32e72..488478e29 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns ThreadGroup: tc, TaskContext: &kernel.TaskContext{Name: name}, Credentials: auth.CredentialsFromContext(ctx), + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), UTSNamespace: kernel.UTSNamespaceFromContext(ctx), IPCNamespace: kernel.IPCNamespaceFromContext(ctx), diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 7f7b791c4..e1b551422 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -16,7 +16,6 @@ package tmpfs import ( "fmt" - "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, err } if opts.Flags&linux.O_TRUNC != 0 { - impl.mu.Lock() - impl.data.Truncate(0, impl.memFile) - atomic.StoreUint64(&impl.size, 0) - impl.mu.Unlock() + if _, err := impl.truncate(0); err != nil { + return nil, err + } } return &fd.vfsfd, nil case *directory: diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index dab346a41..711442424 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -15,6 +15,7 @@ package tmpfs import ( + "fmt" "io" "math" "sync/atomic" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -34,25 +36,53 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// regularFile is a regular (=S_IFREG) tmpfs file. type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. memFile *pgalloc.MemoryFile - // mu protects the fields below. - mu sync.RWMutex + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // mappings tracks mappings of the file into memmap.MappingSpaces. + // + // Protected by mapsMu. + mappings memmap.MappingSet + + // writableMappingPages tracks how many pages of virtual memory are mapped + // as potentially writable from this file. If a page has multiple mappings, + // each mapping is counted separately. + // + // This counter is susceptible to overflow as we can potentially count + // mappings from many VMAs. We count pages rather than bytes to slightly + // mitigate this. + // + // Protected by mapsMu. + writableMappingPages uint64 + + // dataMu protects the fields below. + dataMu sync.RWMutex // data maps offsets into the file to offsets into memFile that store // the file's data. + // + // Protected by dataMu. data fsutil.FileRangeSet - // size is the size of data, but accessed using atomic memory - // operations to avoid locking in inode.stat(). - size uint64 - // seals represents file seals on this inode. + // + // Protected by dataMu. seals uint32 + + // size is the size of data. + // + // Protected by both dataMu and inode.mu; reading it requires holding + // either mutex, while writing requires holding both AND using atomics. + // Readers that do not require consistency (like Stat) may read the + // value atomically without holding either lock. + size uint64 } func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { @@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod // truncate grows or shrinks the file to the given size. It returns true if the // file size was updated. -func (rf *regularFile) truncate(size uint64) (bool, error) { - rf.mu.Lock() - defer rf.mu.Unlock() +func (rf *regularFile) truncate(newSize uint64) (bool, error) { + rf.inode.mu.Lock() + defer rf.inode.mu.Unlock() + return rf.truncateLocked(newSize) +} - if size == rf.size { +// Preconditions: rf.inode.mu must be held. +func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) { + oldSize := rf.size + if newSize == oldSize { // Nothing to do. return false, nil } - if size > rf.size { - // Growing the file. + // Need to hold inode.mu and dataMu while modifying size. + rf.dataMu.Lock() + if newSize > oldSize { + // Can we grow the file? if rf.seals&linux.F_SEAL_GROW != 0 { - // Seal does not allow growth. + rf.dataMu.Unlock() return false, syserror.EPERM } - rf.size = size + // We only need to update the file size. + atomic.StoreUint64(&rf.size, newSize) + rf.dataMu.Unlock() return true, nil } - // Shrinking the file + // We are shrinking the file. First check if this is allowed. if rf.seals&linux.F_SEAL_SHRINK != 0 { - // Seal does not allow shrink. + rf.dataMu.Unlock() return false, syserror.EPERM } - // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have - // mappings. + // Update the file size. + atomic.StoreUint64(&rf.size, newSize) + rf.dataMu.Unlock() + + // Invalidate past translations of truncated pages. + oldpgend := fs.OffsetPageEnd(int64(oldSize)) + newpgend := fs.OffsetPageEnd(int64(newSize)) + if newpgend < oldpgend { + rf.mapsMu.Lock() + rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/shmem.c:shmem_setattr() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + rf.mapsMu.Unlock() + } - rf.data.Truncate(size, rf.memFile) - rf.size = size + // We are now guaranteed that there are no translations of truncated pages, + // and can remove them. + rf.dataMu.Lock() + rf.data.Truncate(newSize, rf.memFile) + rf.dataMu.Unlock() return true, nil } +// AddMapping implements memmap.Mappable.AddMapping. +func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + rf.mapsMu.Lock() + defer rf.mapsMu.Unlock() + rf.dataMu.RLock() + defer rf.dataMu.RUnlock() + + // Reject writable mapping if F_SEAL_WRITE is set. + if rf.seals&linux.F_SEAL_WRITE != 0 && writable { + return syserror.EPERM + } + + rf.mappings.AddMapping(ms, ar, offset, writable) + if writable { + pagesBefore := rf.writableMappingPages + + // ar is guaranteed to be page aligned per memmap.Mappable. + rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize) + + if rf.writableMappingPages < pagesBefore { + panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) + } + } + + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + rf.mapsMu.Lock() + defer rf.mapsMu.Unlock() + + rf.mappings.RemoveMapping(ms, ar, offset, writable) + + if writable { + pagesBefore := rf.writableMappingPages + + // ar is guaranteed to be page aligned per memmap.Mappable. + rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize) + + if rf.writableMappingPages > pagesBefore { + panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages)) + } + } +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + return rf.AddMapping(ctx, ms, dstAR, offset, writable) +} + +// Translate implements memmap.Mappable.Translate. +func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + rf.dataMu.Lock() + defer rf.dataMu.Unlock() + + // Constrain translations to f.attr.Size (rounded up) to prevent + // translation to pages that may be concurrently truncated. + pgend := fs.OffsetPageEnd(int64(rf.size)) + var beyondEOF bool + if required.End > pgend { + if required.Start >= pgend { + return nil, &memmap.BusError{io.EOF} + } + beyondEOF = true + required.End = pgend + } + if optional.End > pgend { + optional.End = pgend + } + + cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { + // Newly-allocated pages are zeroed, so we don't need to do anything. + return dsts.NumBytes(), nil + }) + + var ts []memmap.Translation + var translatedEnd uint64 + for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() { + segMR := seg.Range().Intersect(optional) + ts = append(ts, memmap.Translation{ + Source: segMR, + File: rf.memFile, + Offset: seg.FileRangeOf(segMR).Start, + Perms: usermem.AnyAccess, + }) + translatedEnd = segMR.End + } + + // Don't return the error returned by f.data.Fill if it occurred outside of + // required. + if translatedEnd < required.End && cerr != nil { + return ts, &memmap.BusError{cerr} + } + if beyondEOF { + return ts, &memmap.BusError{io.EOF} + } + return ts, nil +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (*regularFile) InvalidateUnsavable(context.Context) error { + return nil +} + type regularFileFD struct { fileDescription @@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Overflow. return 0, syserror.EFBIG } + f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) + f.inode.mu.Unlock() putRegularFileReadWriter(rw) return n, err } @@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng return nil } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + file := fd.inode().impl.(*regularFile) + return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) +} + // regularFileReadWriter implements safemem.Reader and Safemem.Writer. type regularFileReadWriter struct { file *regularFile @@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) { // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - rw.file.mu.RLock() + rw.file.dataMu.RLock() + defer rw.file.dataMu.RUnlock() + size := rw.file.size // Compute the range to read (limited by file size and overflow-checked). - if rw.off >= rw.file.size { - rw.file.mu.RUnlock() + if rw.off >= size { return 0, io.EOF } - end := rw.file.size + end := size if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end { end = rend } @@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er // Get internal mappings. ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er rw.off += uint64(n) dsts = dsts.DropFirst64(n) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er rw.off += uint64(n) dsts = dsts.DropFirst64(n) if err != nil { - rw.file.mu.RUnlock() return done, err } @@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} } } - rw.file.mu.RUnlock() return done, nil } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +// +// Preconditions: inode.mu must be held. func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - rw.file.mu.Lock() + // Hold dataMu so we can modify size. + rw.file.dataMu.Lock() + defer rw.file.dataMu.Unlock() // Compute the range to write (overflow-checked). end := rw.off + srcs.NumBytes() @@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, // Check if seals prevent either file growth or all writes. switch { case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed - rw.file.mu.Unlock() return 0, syserror.EPERM case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed // When growth is sealed, Linux effectively allows writes which would @@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, } if end <= rw.off { // Truncation would result in no data being written. - rw.file.mu.Unlock() return 0, syserror.EPERM } } @@ -395,9 +563,8 @@ exitLoop: // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.off > rw.file.size { - atomic.StoreUint64(&rw.file.size, rw.off) + rw.file.size = rw.off } - rw.file.mu.Unlock() return done, retErr } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index c5bb17562..521206305 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -18,9 +18,10 @@ // Lock order: // // filesystem.mu -// regularFileFD.offMu -// regularFile.mu // inode.mu +// regularFileFD.offMu +// regularFile.mapsMu +// regularFile.dataMu package tmpfs import ( @@ -226,12 +227,15 @@ func (i *inode) tryIncRef() bool { func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { - // This is unnecessary; it's mostly to simulate what tmpfs would do. if regFile, ok := i.impl.(*regularFile); ok { - regFile.mu.Lock() + // Hold inode.mu and regFile.dataMu while mutating + // size. + i.mu.Lock() + regFile.dataMu.Lock() regFile.data.DropAll(regFile.memFile) atomic.StoreUint64(®File.size, 0) - regFile.mu.Unlock() + regFile.dataMu.Unlock() + i.mu.Unlock() } } else if refs < 0 { panic("tmpfs.inode.decRef() called without holding a reference") @@ -320,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error { if mask&linux.STATX_SIZE != 0 { switch impl := i.impl.(type) { case *regularFile: - updated, err := impl.truncate(stat.Size) + updated, err := impl.truncateLocked(stat.Size) if err != nil { return err } diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 334432abf..07bf39fed 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -10,6 +10,7 @@ go_library( srcs = [ "context.go", "inet.go", + "namespace.go", "test_stack.go", ], deps = [ diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go new file mode 100644 index 000000000..c16667e7f --- /dev/null +++ b/pkg/sentry/inet/namespace.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +// Namespace represents a network namespace. See network_namespaces(7). +// +// +stateify savable +type Namespace struct { + // stack is the network stack implementation of this network namespace. + stack Stack `state:"nosave"` + + // creator allows kernel to create new network stack for network namespaces. + // If nil, no networking will function if network is namespaced. + creator NetworkStackCreator + + // isRoot indicates whether this is the root network namespace. + isRoot bool +} + +// NewRootNamespace creates the root network namespace, with creator +// allowing new network namespaces to be created. If creator is nil, no +// networking will function if the network is namespaced. +func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace { + return &Namespace{ + stack: stack, + creator: creator, + isRoot: true, + } +} + +// NewNamespace creates a new network namespace from the root. +func NewNamespace(root *Namespace) *Namespace { + n := &Namespace{ + creator: root.creator, + } + n.init() + return n +} + +// Stack returns the network stack of n. Stack may return nil if no network +// stack is configured. +func (n *Namespace) Stack() Stack { + return n.stack +} + +// IsRoot returns whether n is the root network namespace. +func (n *Namespace) IsRoot() bool { + return n.isRoot +} + +// RestoreRootStack restores the root network namespace with stack. This should +// only be called when restoring kernel. +func (n *Namespace) RestoreRootStack(stack Stack) { + if !n.isRoot { + panic("RestoreRootStack can only be called on root network namespace") + } + if n.stack != nil { + panic("RestoreRootStack called after a stack has already been set") + } + n.stack = stack +} + +func (n *Namespace) init() { + // Root network namespace will have stack assigned later. + if n.isRoot { + return + } + if n.creator != nil { + var err error + n.stack, err = n.creator.CreateStack() + if err != nil { + panic(err) + } + } +} + +// afterLoad is invoked by stateify. +func (n *Namespace) afterLoad() { + n.init() +} + +// NetworkStackCreator allows new instances of a network stack to be created. It +// is used by the kernel to create new network namespaces when requested. +type NetworkStackCreator interface { + // CreateStack creates a new network stack for a network namespace. + CreateStack() (Stack, error) +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 23b88f7a6..58001d56c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return fds, nil } +// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for +// the given file description. If it succeeds, it takes a reference on file. +func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + if minfd < 0 { + // Don't accept negative FDs. + return -1, syscall.EINVAL + } + + // Default limit. + end := int32(math.MaxInt32) + + // Ensure we don't get past the provided limit. + if limitSet := limits.FromContext(ctx); limitSet != nil { + lim := limitSet.Get(limits.NumberOfFiles) + if lim.Cur != limits.Infinity { + end = int32(lim.Cur) + } + if minfd >= end { + return -1, syscall.EMFILE + } + } + + f.mu.Lock() + defer f.mu.Unlock() + + // From f.next to find available fd. + fd := minfd + if fd < f.next { + fd = f.next + } + for fd < end { + if d, _, _ := f.get(fd); d == nil { + f.setVFS2(fd, file, flags) + if fd == f.next { + // Update next search start position. + f.next = fd + 1 + } + return fd, nil + } + fd++ + } + return -1, syscall.EMFILE +} + // NewFDAt sets the file reference for the given FD. If there is an active // reference for that FD, the ref count for that existing reference is // decremented. @@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 return syscall.EBADF } - f.mu.Lock() - defer f.mu.Unlock() - // Check the limit for the provided file. if limitSet := limits.FromContext(ctx); limitSet != nil { if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur { @@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 } // Install the entry. + f.mu.Lock() + defer f.mu.Unlock() f.setAll(fd, file, fileVFS2, flags) return nil } diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 7218aa24e..47f78df9a 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -244,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) { old.DecRef() } +// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd. +// +// This is not a valid call after free. +func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { + if !vd.Ok() { + panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry") + } + + f.mu.Lock() + + if !f.rootVFS2.Ok() { + f.mu.Unlock() + panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd)) + } + + old := f.rootVFS2 + vd.IncRef() + f.rootVFS2 = vd + f.mu.Unlock() + old.DecRef() +} + // Umask returns the current umask. func (f *FSContext) Umask() uint { f.mu.Lock() diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 7da0368f1..8b76750e9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -111,7 +111,7 @@ type Kernel struct { timekeeper *Timekeeper tasks *TaskSet rootUserNamespace *auth.UserNamespace - networkStack inet.Stack `state:"nosave"` + rootNetworkNamespace *inet.Namespace applicationCores uint useHostCores bool extraAuxv []arch.AuxEntry @@ -247,6 +247,10 @@ type Kernel struct { // VFS keeps the filesystem state used across the kernel. vfs vfs.VirtualFilesystem + + // If set to true, report address space activation waits as if the task is in + // external wait so that the watchdog doesn't report the task stuck. + SleepForAddressSpaceActivation bool } // InitKernelArgs holds arguments to Init. @@ -260,8 +264,9 @@ type InitKernelArgs struct { // RootUserNamespace is the root user namespace. RootUserNamespace *auth.UserNamespace - // NetworkStack is the TCP/IP network stack. NetworkStack may be nil. - NetworkStack inet.Stack + // RootNetworkNamespace is the root network namespace. If nil, no networking + // will be available. + RootNetworkNamespace *inet.Namespace // ApplicationCores is the number of logical CPUs visible to sandboxed // applications. The set of logical CPU IDs is [0, ApplicationCores); thus @@ -320,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.rootUTSNamespace = args.RootUTSNamespace k.rootIPCNamespace = args.RootIPCNamespace k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace - k.networkStack = args.NetworkStack + k.rootNetworkNamespace = args.RootNetworkNamespace + if k.rootNetworkNamespace == nil { + k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil) + } k.applicationCores = args.ApplicationCores if args.UseHostCores { k.useHostCores = true @@ -543,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() { func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() - k.networkStack = net - initAppCores := k.applicationCores // Load the pre-saved CPUID FeatureSet. @@ -575,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) log.Infof("Kernel load stats: %s", &stats) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) + // rootNetworkNamespace should be populated after loading the state file. + // Restore the root network stack. + k.rootNetworkNamespace.RestoreRootStack(net) + // Load the memory file's state. memoryStart := time.Now() if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { @@ -905,6 +915,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, FSContext: fsContext, FDTable: args.FDTable, Credentials: args.Credentials, + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores), UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, @@ -1255,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace { return k.rootAbstractSocketNamespace } -// NetworkStack returns the network stack. NetworkStack may return nil if no -// network stack is available. -func (k *Kernel) NetworkStack() inet.Stack { - return k.networkStack +// RootNetworkNamespace returns the root network namespace, always non-nil. +func (k *Kernel) RootNetworkNamespace() *inet.Namespace { + return k.rootNetworkNamespace } // GlobalInit returns the thread group with ID 1 in the root PID namespace, or diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 18416643b..ded95f532 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() { } var cs linux.RSeqCriticalSection - if _, err := cs.CopyIn(t, critAddr); err != nil { + if err := cs.CopyIn(t, critAddr); err != nil { t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index a3443ff21..2cee2e6ed 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -486,13 +486,10 @@ type Task struct { numaPolicy int32 numaNodeMask uint64 - // If netns is true, the task is in a non-root network namespace. Network - // namespaces aren't currently implemented in full; being in a network - // namespace simply prevents the task from observing any network devices - // (including loopback) or using abstract socket addresses (see unix(7)). + // netns is the task's network namespace. netns is never nil. // - // netns is protected by mu. netns is owned by the task goroutine. - netns bool + // netns is protected by mu. + netns *inet.Namespace // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and @@ -792,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error) return fds[0], nil } +// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable.Get. +func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + return t.fdTable.NewFDVFS2(t, fd, file, flags) +} + // NewFDAt is a convenience wrapper for t.FDTable().NewFDAt. // // This automatically passes the task as the context. @@ -801,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error { return t.fdTable.NewFDAt(t, fd, file, flags) } +// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable. +func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error { + return t.fdTable.NewFDAtVFS2(t, fd, file, flags) +} + // WithMuLocked executes f with t.mu locked. func (t *Task) WithMuLocked(f func(*Task)) { t.mu.Lock() diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index ba74b4c1c..78866f280 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -54,8 +55,7 @@ type SharingOptions struct { NewUserNamespace bool // If NewNetworkNamespace is true, the task should have an independent - // network namespace. (Note that network namespaces are not really - // implemented; see comment on Task.netns for details.) + // network namespace. NewNetworkNamespace bool // If NewFiles is true, the task should use an independent file descriptor @@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { ipcns = NewIPCNamespace(userns) } + netns := t.NetworkNamespace() + if opts.NewNetworkNamespace { + netns = inet.NewNamespace(netns) + } + // TODO(b/63601033): Implement CLONE_NEWNS. mntnsVFS2 := t.mountNamespaceVFS2 if mntnsVFS2 != nil { @@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { FDTable: fdTable, Credentials: creds, Niceness: t.Niceness(), - NetworkNamespaced: t.netns, + NetworkNamespace: netns, AllowedCPUMask: t.CPUMask(), UTSNamespace: utsns, IPCNamespace: ipcns, @@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { cfg.InheritParent = t } - if opts.NewNetworkNamespace { - cfg.NetworkNamespaced = true - } nt, err := t.tg.pidns.owner.NewTask(cfg) if err != nil { if opts.NewThreadGroup { @@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { t.mu.Unlock() return syserror.EPERM } - t.netns = true + t.netns = inet.NewNamespace(t.netns) } if opts.NewUTSNamespace { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 2be982684..0158b1788 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -140,7 +140,7 @@ func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*Task } // Prepare a new user address space to load into. - m := mm.NewMemoryManager(k, k) + m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) defer m.DecUsers(ctx) args.MemoryManager = m diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 8f57a34a6..00c425cca 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.mu.Unlock() t.unstopVforkParent() // NOTE(b/30316266): All locks must be dropped prior to calling Activate. - t.MemoryManager().Activate() + t.MemoryManager().Activate(t) t.ptraceExec(oldTID) return (*runSyscallExit)(nil) diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index 6d737d3e5..eeccaa197 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -32,21 +32,21 @@ const ( // Infof logs an formatted info message by calling log.Infof. func (t *Task) Infof(fmt string, v ...interface{}) { if log.IsLogging(log.Info) { - log.Infof(t.logPrefix.Load().(string)+fmt, v...) + log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } // Warningf logs a warning string by calling log.Warningf. func (t *Task) Warningf(fmt string, v ...interface{}) { if log.IsLogging(log.Warning) { - log.Warningf(t.logPrefix.Load().(string)+fmt, v...) + log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } // Debugf creates a debug string that includes the task ID. func (t *Task) Debugf(fmt string, v ...interface{}) { if log.IsLogging(log.Debug) { - log.Debugf(t.logPrefix.Load().(string)+fmt, v...) + log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...) } } diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index 172a31e1d..f7711232c 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -22,14 +22,23 @@ import ( func (t *Task) IsNetworkNamespaced() bool { t.mu.Lock() defer t.mu.Unlock() - return t.netns + return !t.netns.IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext // may return nil if no network stack is available. +// +// TODO(gvisor.dev/issue/1833): Migrate callers of this method to +// NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - if t.IsNetworkNamespaced() { - return nil - } - return t.k.networkStack + t.mu.Lock() + defer t.mu.Unlock() + return t.netns.Stack() +} + +// NetworkNamespace returns the network namespace observed by the task. +func (t *Task) NetworkNamespace() *inet.Namespace { + t.mu.Lock() + defer t.mu.Unlock() + return t.netns } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index f9236a842..a5035bb7f 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" @@ -65,9 +66,8 @@ type TaskConfig struct { // Niceness is the niceness of the new task. Niceness int - // If NetworkNamespaced is true, the new task should observe a non-root - // network namespace. - NetworkNamespaced bool + // NetworkNamespace is the network namespace to be used for the new task. + NetworkNamespace *inet.Namespace // AllowedCPUMask contains the cpus that this task can run on. AllowedCPUMask sched.CPUSet @@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespaced, + netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 2bf3ce8a8..b02044ad2 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown()) // Activate ensures that the task has an active address space. func (t *Task) Activate() { if mm := t.MemoryManager(); mm != nil { - if err := mm.Activate(); err != nil { + if err := mm.Activate(t); err != nil { panic("unable to activate mm: " + err.Error()) } } diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index 94d39af60..0332fc71c 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -18,6 +18,7 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -38,7 +39,7 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace { // // When this MemoryManager is no longer needed by a task, it should call // Deactivate to release the reference. -func (mm *MemoryManager) Activate() error { +func (mm *MemoryManager) Activate(ctx context.Context) error { // Fast path: the MemoryManager already has an active // platform.AddressSpace, and we just need to indicate that we need it too. for { @@ -91,16 +92,20 @@ func (mm *MemoryManager) Activate() error { if as == nil { // AddressSpace is unavailable, we must wait. // - // activeMu must not be held while waiting, as the user - // of the address space we are waiting on may attempt - // to take activeMu. - // - // Don't call UninterruptibleSleepStart to register the - // wait to allow the watchdog stuck task to trigger in - // case a process is starved waiting for the address - // space. + // activeMu must not be held while waiting, as the user of the address + // space we are waiting on may attempt to take activeMu. mm.activeMu.Unlock() + + sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation + if sleep { + // Mark this task sleeping while waiting for the address space to + // prevent the watchdog from reporting it as a stuck task. + ctx.UninterruptibleSleepStart(false) + } <-c + if sleep { + ctx.UninterruptibleSleepFinish(false) + } continue } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 3c263ebaa..d8a5b9d29 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -28,16 +28,17 @@ import ( ) // NewMemoryManager returns a new MemoryManager with no mappings and 1 user. -func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager { +func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager { return &MemoryManager{ - p: p, - mfp: mfp, - haveASIO: p.SupportsAddressSpaceIO(), - privateRefs: &privateRefs{}, - users: 1, - auxv: arch.Auxv{}, - dumpability: UserDumpable, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + p: p, + mfp: mfp, + haveASIO: p.SupportsAddressSpaceIO(), + privateRefs: &privateRefs{}, + users: 1, + auxv: arch.Auxv{}, + dumpability: UserDumpable, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + sleepForActivation: sleepForActivation, } } @@ -79,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { envv: mm.envv, auxv: append(arch.Auxv(nil), mm.auxv...), // IncRef'd below, once we know that there isn't an error. - executable: mm.executable, - dumpability: mm.dumpability, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + executable: mm.executable, + dumpability: mm.dumpability, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + sleepForActivation: mm.sleepForActivation, } // Copy vmas. diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 637383c7a..c2195ae11 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -226,6 +226,11 @@ type MemoryManager struct { // aioManager keeps track of AIOContexts used for async IOs. AIOManager // must be cloned when CLONE_VM is used. aioManager aioManager + + // sleepForActivation indicates whether the task should report to be sleeping + // before trying to activate the address space. When set to true, delays in + // activation are not reported as stuck tasks by the watchdog. + sleepForActivation bool } // vma represents a virtual memory area. diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index edacca741..fdc308542 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -31,7 +31,7 @@ import ( func testMemoryManager(ctx context.Context) *MemoryManager { p := platform.FromContext(ctx) mfp := pgalloc.MemoryFileProviderFromContext(ctx) - mm := NewMemoryManager(p, mfp) + mm := NewMemoryManager(p, mfp, false) mm.layout = arch.MmapLayout{ MinAddr: p.MinUserAddress(), MaxAddr: p.MaxUserAddress(), diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 8076c7529..f1afc74dc 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -329,10 +329,12 @@ func (m *machine) Destroy() { } // Get gets an available vCPU. +// +// This will return with the OS thread locked. func (m *machine) Get() *vCPU { + m.mu.RLock() runtime.LockOSThread() tid := procid.Current() - m.mu.RLock() // Check for an exact match. if c := m.vCPUs[tid]; c != nil { @@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU { // The happy path failed. We now proceed to acquire an exclusive lock // (because the vCPU map may change), and scan all available vCPUs. + // In this case, we first unlock the OS thread. Otherwise, if mu is + // not available, the current system thread will be parked and a new + // system thread spawned. We avoid this situation by simply refreshing + // tid after relocking the system thread. m.mu.RUnlock() + runtime.UnlockOSThread() m.mu.Lock() + runtime.LockOSThread() + tid = procid.Current() + + // Recheck for an exact match. + if c := m.vCPUs[tid]; c != nil { + c.lock() + m.mu.Unlock() + return c + } for { // Scan for an available vCPU. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index baa6c4910..d42eda37b 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -25,10 +25,14 @@ // not available for calls. // +// ERET returns using the ELR and SPSR for the current exception level. #define ERET() \ WORD $0xd69f03e0 +// RSV_REG is a register that holds el1 information temporarily. #define RSV_REG R18_PLATFORM + +// RSV_REG_APP is a register that holds el0 information temporarily. #define RSV_REG_APP R9 #define FPEN_NOTRAP 0x3 @@ -36,6 +40,12 @@ #define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) +// Saves a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// +// The following registers are not saved: R9, R18. #define REGISTERS_SAVE(reg, offset) \ MOVD R0, offset+PTRACE_R0(reg); \ MOVD R1, offset+PTRACE_R1(reg); \ @@ -67,6 +77,12 @@ MOVD R29, offset+PTRACE_R29(reg); \ MOVD R30, offset+PTRACE_R30(reg); +// Loads a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// +// The following registers are not loaded: R9, R18. #define REGISTERS_LOAD(reg, offset) \ MOVD offset+PTRACE_R0(reg), R0; \ MOVD offset+PTRACE_R1(reg), R1; \ @@ -98,7 +114,7 @@ MOVD offset+PTRACE_R29(reg), R29; \ MOVD offset+PTRACE_R30(reg), R30; -//NOP +// NOP-s #define nop31Instructions() \ WORD $0xd503201f; \ WORD $0xd503201f; \ @@ -254,6 +270,7 @@ #define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) #define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +// LOAD_KERNEL_ADDRESS loads a kernel address. #define LOAD_KERNEL_ADDRESS(from, to) \ MOVD from, to; \ ORR $0xffff000000000000, to, to; @@ -263,15 +280,18 @@ LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \ MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \ MOVD RSV_REG, RSP; \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 ISB $15; \ DSB $15; +// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. #define SWITCH_TO_APP_PAGETABLE(from) \ MOVD CPU_TTBR0_APP(from), RSV_REG; \ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 ISB $15; \ DSB $15; +// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. #define SWITCH_TO_KVM_PAGETABLE(from) \ MOVD CPU_TTBR0_KVM(from), RSV_REG; \ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 @@ -294,6 +314,7 @@ WORD $0xd5181040; \ //MSR R0, CPACR_EL1 ISB $15; +// KERNEL_ENTRY_FROM_EL0 is the entry code of the vcpu from el0 to el1. #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ @@ -315,19 +336,22 @@ WORD $0xd5384103; \ // MRS SP_EL0, R3 MOVD R3, PTRACE_SP(RSV_REG_APP); +// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1. #define KERNEL_ENTRY_FROM_EL1 \ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 - REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context + REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context. MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \ WORD $0xd5384004; \ // MRS SPSR_EL1, R4 MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \ MRS ELR_EL1, R4; \ MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \ MOVD RSP, R4; \ - MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); + MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ + LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// Halt halts execution. TEXT ·Halt(SB),NOSPLIT,$0 - // clear bluepill. + // Clear bluepill. WORD $0xd538d092 //MRS TPIDR_EL1, R18 CMP RSV_REG, R9 BNE mmio_exit @@ -341,8 +365,22 @@ mmio_exit: // MMIO_EXIT. MOVD $0, R9 MOVD R0, 0xffff000000001000(R9) - B ·kernelExitToEl1(SB) + RET + +// HaltAndResume halts execution and point the pointer to the resume function. +TEXT ·HaltAndResume(SB),NOSPLIT,$0 + BL ·Halt(SB) + B ·kernelExitToEl1(SB) // Resume. +// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume. +TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0 + WORD $0xd538d092 // MRS TPIDR_EL1, R18 + MOVD CPU_SELF(RSV_REG), R3 // Load vCPU. + MOVD R3, 8(RSP) // First argument (vCPU). + CALL ·kernelSyscall(SB) // Call the trampoline. + B ·kernelExitToEl1(SB) // Resume. + +// Shutdown stops the guest. TEXT ·Shutdown(SB),NOSPLIT,$0 // PSCI EVENT. MOVD $0x84000009, R0 @@ -429,6 +467,7 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 ERET() +// Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 IRQ_DISABLE MOVD R8, RSV_REG @@ -437,18 +476,23 @@ TEXT ·Start(SB),NOSPLIT,$0 B ·kernelExitToEl1(SB) +// El1_sync_invalid is the handler for an invalid EL1_sync. TEXT ·El1_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_irq_invalid is the handler for an invalid El1_irq. TEXT ·El1_irq_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_fiq_invalid is the handler for an invalid El1_fiq. TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_error_invalid is the handler for an invalid El1_error. TEXT ·El1_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_sync is the handler for El1_sync. TEXT ·El1_sync(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL1 WORD $0xd5385219 // MRS ESR_EL1, R25 @@ -484,10 +528,10 @@ el1_da: MOVD $PageFault, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·Halt(SB) + B ·HaltAndResume(SB) el1_ia: - B ·Halt(SB) + B ·HaltAndResume(SB) el1_sp_pc: B ·Shutdown(SB) @@ -496,7 +540,9 @@ el1_undef: B ·Shutdown(SB) el1_svc: - B ·Halt(SB) + MOVD $0, CPU_ERROR_CODE(RSV_REG) + MOVD $0, CPU_ERROR_TYPE(RSV_REG) + B ·HaltEl1SvcAndResume(SB) el1_dbg: B ·Shutdown(SB) @@ -508,15 +554,19 @@ el1_fpsimd_acc: el1_invalid: B ·Shutdown(SB) +// El1_irq is the handler for El1_irq. TEXT ·El1_irq(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_fiq is the handler for El1_fiq. TEXT ·El1_fiq(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El1_error is the handler for El1_error. TEXT ·El1_error(SB),NOSPLIT,$0 B ·Shutdown(SB) +// El0_sync is the handler for El0_sync. TEXT ·El0_sync(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL0 WORD $0xd5385219 // MRS ESR_EL1, R25 @@ -554,7 +604,7 @@ el0_svc: MOVD $Syscall, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·Halt(SB) + B ·HaltAndResume(SB) el0_da: WORD $0xd538d092 //MRS TPIDR_EL1, R18 @@ -568,7 +618,7 @@ el0_da: MOVD $PageFault, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·Halt(SB) + B ·HaltAndResume(SB) el0_ia: B ·Shutdown(SB) @@ -613,7 +663,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0 MOVD $VirtualizationException, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·Halt(SB) + B ·HaltAndResume(SB) TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) @@ -627,6 +677,7 @@ TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0 TEXT ·El0_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) +// Vectors implements exception vector table. TEXT ·Vectors(SB),NOSPLIT,$0 B ·El1_sync_invalid(SB) nop31Instructions() diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index c3d341998..ccacaea6b 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -16,6 +16,14 @@ package ring0 +// HaltAndResume halts execution and point the pointer to the resume function. +//go:nosplit +func HaltAndResume() + +// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume. +//go:nosplit +func HaltEl1SvcAndResume() + // init initializes architecture-specific state. func (k *Kernel) init(opts KernelOpts) { // Save the root page tables. diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2f39a6f2b..88d5db9fc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "capability.go", "clone.go", + "epoll.go", "futex.go", "linux64_amd64.go", "linux64_arm64.go", diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go new file mode 100644 index 000000000..a6e48b836 --- /dev/null +++ b/pkg/sentry/strace/epoll.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + "strings" + + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/usermem" +) + +func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { + var e linux.EpollEvent + if _, err := t.CopyIn(eventAddr, &e); err != nil { + return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) + } + var sb strings.Builder + fmt.Fprintf(&sb, "%#x ", eventAddr) + writeEpollEvent(&sb, e) + return sb.String() +} + +func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string { + var sb strings.Builder + fmt.Fprintf(&sb, "%#x {", eventsAddr) + addr := eventsAddr + for i := uint64(0); i < numEvents; i++ { + var e linux.EpollEvent + if _, err := t.CopyIn(addr, &e); err != nil { + fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err) + continue + } + writeEpollEvent(&sb, e) + if uint64(sb.Len()) >= maxBytes { + sb.WriteString("...") + break + } + if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok { + fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr) + continue + } + } + sb.WriteString("}") + return sb.String() +} + +func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) { + events := epollEventEvents.Parse(uint64(e.Events)) + fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1]) +} + +var epollCtlOps = abi.ValueSet{ + linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD", + linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL", + linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD", +} + +var epollEventEvents = abi.FlagSet{ + {Flag: linux.EPOLLIN, Name: "EPOLLIN"}, + {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"}, + {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"}, + {Flag: linux.EPOLLERR, Name: "EPOLLERR"}, + {Flag: linux.EPOLLHUP, Name: "EPULLHUP"}, + {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"}, + {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"}, + {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"}, + {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"}, + {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"}, + {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"}, + {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"}, + {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"}, + {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"}, + {Flag: linux.EPOLLET, Name: "EPOLLET"}, +} diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index a4de545e9..71b92eaee 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{ 229: makeSyscallInfo("clock_getres", Hex, PostTimespec), 230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec), 231: makeSyscallInfo("exit_group", Hex), - 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex), - 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex), + 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex), + 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent), 234: makeSyscallInfo("tgkill", Hex, Hex, Signal), 235: makeSyscallInfo("utimes", Path, Timeval), // 236: vserver (not implemented in the Linux kernel) @@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{ 278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex), 279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex), 280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex), - 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex), + 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex), 282: makeSyscallInfo("signalfd", Hex, Hex, Hex), 283: makeSyscallInfo("timerfd_create", Hex, Hex), 284: makeSyscallInfo("eventfd", Hex), diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index 8bc38545f..bd7361a52 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex), 19: makeSyscallInfo("eventfd2", Hex, Hex), 20: makeSyscallInfo("epoll_create1", Hex), - 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex), - 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex), + 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent), + 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex), 23: makeSyscallInfo("dup", FD), 24: makeSyscallInfo("dup3", FD, FD, Hex), 25: makeSyscallInfo("fcntl", FD, Hex, Hex), diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 51e6d81b2..c0512de89 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -632,4 +632,13 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.MCAST_MSFILTER: "MCAST_MSFILTER", linux.IPV6_ADDRFORM: "IPV6_ADDRFORM", }, + linux.SOL_NETLINK: { + linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR", + linux.NETLINK_CAP_ACK: "NETLINK_CAP_ACK", + linux.NETLINK_DUMP_STRICT_CHK: "NETLINK_DUMP_STRICT_CHK", + linux.NETLINK_EXT_ACK: "NETLINK_EXT_ACK", + linux.NETLINK_LIST_MEMBERSHIPS: "NETLINK_LIST_MEMBERSHIPS", + linux.NETLINK_NO_ENOBUFS: "NETLINK_NO_ENOBUFS", + linux.NETLINK_PKTINFO: "NETLINK_PKTINFO", + }, } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 46cb2a1cc..77655558e 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -481,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case EpollCtlOp: + output = append(output, epollCtlOps.Parse(uint64(args[arg].Int()))) + case EpollEvent: + output = append(output, epollEvent(t, args[arg].Pointer())) + case EpollEvents: + output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize))) case SelectFDSet: output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: @@ -549,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer()) case PollFDs: output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true) + case EpollEvents: + output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize)) case GetSockOptVal: output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval) case SetSockOptVal: diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 446d1e0f6..7e69b9279 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -228,6 +228,16 @@ const ( // SockOptLevel is the optname argument in getsockopt(2) and // setsockopt(2). SockOptName + + // EpollCtlOp is the op argument to epoll_ctl(2). + EpollCtlOp + + // EpollEvent is the event argument in epoll_ctl(2). + EpollEvent + + // EpollEvents is an array of struct epoll_event. It is the events + // argument in epoll_wait(2)/epoll_pwait(2). + EpollEvents ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index fbef5b376..3ab93fbde 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + // EpollCreate1 implements the epoll_create1(2) linux syscall. func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { flags := args[0].Int() @@ -164,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 421845ebb..d10a9bed8 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string return path, dirPath, nil } +// LINT.IfChange + func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + // Ioctl implements linux syscall ioctl(2). func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } +// LINT.ThenChange(vfs2/ioctl.go) + +// LINT.IfChange + // Getcwd implements the linux syscall getcwd(2). func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() @@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, nil } +// LINT.ThenChange(vfs2/fscontext.go) + +// LINT.IfChange + // Close implements linux syscall close(2). func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } +// LINT.ThenChange(vfs2/fd.go) + const ( _FADV_NORMAL = 0 _FADV_RANDOM = 1 @@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.IfChange + func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1218,7 +1236,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return syserror.ENOTEMPTY } - if err := fs.MayDelete(t, root, d, name); err != nil { + if err := d.MayDelete(t, root, name); err != nil { return err } @@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return n, nil, err } +// LINT.ThenChange(vfs2/stat.go) + +// LINT.IfChange + func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1491,7 +1517,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return syserror.ENOTDIR } - if err := fs.MayDelete(t, root, d, name); err != nil { + if err := d.MayDelete(t, root, name); err != nil { return err } @@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, unlinkAt(t, dirFD, addr) } +// LINT.ThenChange(vfs2/filesystem.go) + +// LINT.IfChange + // Truncate implements linux syscall truncate(2). func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() @@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.ThenChange(vfs2/setstat.go) + // Umask implements linux syscall umask(2). func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { mask := args[0].ModeT() @@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(mask), nil, nil } +// LINT.IfChange + // Change ownership of a file. // // uid and gid may be -1, in which case they will not be changed. @@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true) } +// LINT.ThenChange(vfs2/setstat.go) + +// LINT.IfChange + func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error { newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */) if err != nil { @@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr) } +// LINT.ThenChange(vfs2/filesystem.go) + // Fallocate implements linux system call fallocate(2). func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index f66f4ffde..b126fecc0 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -27,6 +27,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // Getdents implements linux syscall getdents(2) for 64bit systems. func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error { func (ds *direntSerializer) Written() int { return ds.written } + +// LINT.ThenChange(vfs2/getdents.go) diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 297e920c4..3f7691eae 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LINT.IfChange + // Lseek implements linux syscall lseek(2). func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() @@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } return uintptr(offset), nil, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 9959f6e61..91694d374 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return uintptr(addr), nil, nil } +// LINT.IfChange + // Mmap implements linux syscall mmap(2). func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { prot := args[2].Int() @@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(rv), nil, err } +// LINT.ThenChange(vfs2/mmap.go) + // Munmap implements linux syscall munmap(2). func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64()) diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 227692f06..78a2cb750 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -28,6 +28,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + const ( // EventMaskRead contains events that can be triggered on reads. EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr @@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i return total, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 8b66a9006..701b27b4a 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -23,6 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat { return linux.Stat{ Dev: sattr.DeviceID, @@ -131,8 +133,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err return err } s := statFromAttrs(t, d.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // fstat implements fstat for the given *fs.File. @@ -142,8 +143,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error { return err } s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // Statx implements linux syscall statx(2). @@ -299,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { _, err = t.CopyOut(addr, &statfs) return err } + +// LINT.ThenChange(vfs2/stat.go) diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 3e55235bd..5ad465ae3 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -22,6 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LINT.IfChange + // Sync implements linux system call sync(2). func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { t.MountNamespace().SyncAll(t) @@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) } + +// LINT.ThenChange(vfs2/sync.go) diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index aba892939..506ee54ce 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -28,6 +28,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + const ( // EventMaskWrite contains events that can be triggered on writes. // @@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) ( return total, err } + +// LINT.ThenChange(vfs2/read_write.go) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 9d8140b8a..2de5e3422 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // GetXattr implements linux syscall getxattr(2). func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return getXattrFromPath(t, args, true) @@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error { return d.Inode.RemoveXattr(t, d, name) } + +// LINT.ThenChange(vfs2/xattr.go) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 6b8a00b6e..f51761e81 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -5,18 +5,44 @@ package(licenses = ["notice"]) go_library( name = "vfs2", srcs = [ + "epoll.go", + "epoll_unsafe.go", + "execve.go", + "fd.go", + "filesystem.go", + "fscontext.go", + "getdents.go", + "ioctl.go", "linux64.go", "linux64_override_amd64.go", "linux64_override_arm64.go", - "sys_read.go", + "mmap.go", + "path.go", + "poll.go", + "read_write.go", + "setstat.go", + "stat.go", + "sync.go", + "xattr.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", + "//pkg/fspath", + "//pkg/gohacks", "//pkg/sentry/arch", + "//pkg/sentry/fsbridge", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", + "//pkg/sentry/limits", + "//pkg/sentry/loader", + "//pkg/sentry/memmap", "//pkg/sentry/syscalls", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go new file mode 100644 index 000000000..d6cb0e79a --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -0,0 +1,225 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "math" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// EpollCreate1 implements Linux syscall epoll_create1(2). +func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + flags := args[0].Int() + if flags&^linux.EPOLL_CLOEXEC != 0 { + return 0, nil, syserror.EINVAL + } + + file, err := t.Kernel().VFS().NewEpollInstanceFD() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil +} + +// EpollCreate implements Linux syscall epoll_create(2). +func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + size := args[0].Int() + + // "Since Linux 2.6.8, the size argument is ignored, but must be greater + // than zero" - epoll_create(2) + if size <= 0 { + return 0, nil, syserror.EINVAL + } + + file, err := t.Kernel().VFS().NewEpollInstanceFD() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil +} + +// EpollCtl implements Linux syscall epoll_ctl(2). +func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + op := args[1].Int() + fd := args[2].Int() + eventAddr := args[3].Pointer() + + epfile := t.GetFileVFS2(epfd) + if epfile == nil { + return 0, nil, syserror.EBADF + } + defer epfile.DecRef() + ep, ok := epfile.Impl().(*vfs.EpollInstance) + if !ok { + return 0, nil, syserror.EINVAL + } + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + if epfile == file { + return 0, nil, syserror.EINVAL + } + + var event linux.EpollEvent + switch op { + case linux.EPOLL_CTL_ADD: + if err := event.CopyIn(t, eventAddr); err != nil { + return 0, nil, err + } + return 0, nil, ep.AddInterest(file, fd, event) + case linux.EPOLL_CTL_DEL: + return 0, nil, ep.DeleteInterest(file, fd) + case linux.EPOLL_CTL_MOD: + if err := event.CopyIn(t, eventAddr); err != nil { + return 0, nil, err + } + return 0, nil, ep.ModifyInterest(file, fd, event) + default: + return 0, nil, syserror.EINVAL + } +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeout := int(args[3].Int()) + + const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS + if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { + return 0, nil, syserror.EINVAL + } + + epfile := t.GetFileVFS2(epfd) + if epfile == nil { + return 0, nil, syserror.EBADF + } + defer epfile.DecRef() + ep, ok := epfile.Impl().(*vfs.EpollInstance) + if !ok { + return 0, nil, syserror.EINVAL + } + + // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent, + // maxEvents), so that the buffer can be allocated on the stack. + var ( + events [16]linux.EpollEvent + total int + ch chan struct{} + haveDeadline bool + deadline ktime.Time + ) + for { + batchEvents := len(events) + if batchEvents > maxEvents { + batchEvents = maxEvents + } + n := ep.ReadEvents(events[:batchEvents]) + maxEvents -= n + if n != 0 { + // Copy what we read out. + copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n]) + eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent) + total += copiedEvents + if err != nil { + if total != 0 { + return uintptr(total), nil, nil + } + return 0, nil, err + } + // If we've filled the application's event buffer, we're done. + if maxEvents == 0 { + return uintptr(total), nil, nil + } + // Loop if we read a full batch, under the expectation that there + // may be more events to read. + if n == batchEvents { + continue + } + } + // We get here if n != batchEvents. If we read any number of events + // (just now, or in a previous iteration of this loop), or if timeout + // is 0 (such that epoll_wait should be non-blocking), return the + // events we've read so far to the application. + if total != 0 || timeout == 0 { + return uintptr(total), nil, nil + } + // In the first iteration of this loop, register with the epoll + // instance for readability events, but then immediately continue the + // loop since we need to retry ReadEvents() before blocking. In all + // subsequent iterations, block until events are available, the timeout + // expires, or an interrupt arrives. + if ch == nil { + var w waiter.Entry + w, ch = waiter.NewChannelEntry(nil) + epfile.EventRegister(&w, waiter.EventIn) + defer epfile.EventUnregister(&w) + } else { + // Set up the timer if a timeout was specified. + if timeout > 0 && !haveDeadline { + timeoutDur := time.Duration(timeout) * time.Millisecond + deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) + haveDeadline = true + } + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + err = nil + } + // total must be 0 since otherwise we would have returned + // above. + return 0, nil, err + } + } + } +} + +// EpollPwait implements Linux syscall epoll_pwait(2). +func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return EpollWait(t, args) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go new file mode 100644 index 000000000..825f325bf --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "reflect" + "runtime" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/usermem" +) + +const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{})) + +func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) { + if len(events) == 0 { + return 0, nil + } + // Cast events to a byte slice for copying. + var eventBytes []byte + eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes)) + eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0]))) + eventBytesHdr.Len = len(events) * sizeofEpollEvent + eventBytesHdr.Cap = len(events) * sizeofEpollEvent + copiedBytes, err := t.CopyOutBytes(addr, eventBytes) + runtime.KeepAlive(events) + copiedEvents := copiedBytes / sizeofEpollEvent // rounded down + return copiedEvents, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go new file mode 100644 index 000000000..aef0078a8 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -0,0 +1,137 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsbridge" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/loader" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Execve implements linux syscall execve(2). +func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathnameAddr := args[0].Pointer() + argvAddr := args[1].Pointer() + envvAddr := args[2].Pointer() + return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) + if err != nil { + return 0, nil, err + } + var argv, envv []string + if argvAddr != 0 { + var err error + argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize) + if err != nil { + return 0, nil, err + } + } + if envvAddr != 0 { + var err error + envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize) + if err != nil { + return 0, nil, err + } + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + var executable fsbridge.File + closeOnExec := false + if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute { + // We must open the executable ourselves since dirfd is used as the + // starting point while resolving path, but the task working directory + // is used as the starting point while resolving interpreters (Linux: + // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() => + // do_open_execat(fd=AT_FDCWD)), and the loader package is currently + // incapable of handling this correctly. + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return 0, nil, syserror.ENOENT + } + dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd) + if dirfile == nil { + return 0, nil, syserror.EBADF + } + start := dirfile.VirtualDentry() + start.IncRef() + dirfile.DecRef() + closeOnExec = dirfileFlags.CloseOnExec + file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + FileExec: true, + }) + start.DecRef() + if err != nil { + return 0, nil, err + } + defer file.DecRef() + executable = fsbridge.NewVFSFile(file) + } + + // Load the new TaskContext. + mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change + defer mntns.DecRef() + wd := t.FSContext().WorkingDirectoryVFS2() + defer wd.DecRef() + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Opener: fsbridge.NewVFSLookup(mntns, root, wd), + RemainingTraversals: &remainingTraversals, + ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + if se != nil { + return 0, nil, se.ToError() + } + + ctrl, err := t.Execve(tc) + return 0, ctrl, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go new file mode 100644 index 000000000..3afcea665 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -0,0 +1,147 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Close implements Linux syscall close(2). +func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + // Note that Remove provides a reference on the file that we may use to + // flush. It is still active until we drop the final reference below + // (and other reference-holding operations complete). + _, file := t.FDTable().Remove(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + err := file.OnClose(t) + return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file) +} + +// Dup implements Linux syscall dup(2). +func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) + if err != nil { + return 0, nil, syserror.EMFILE + } + return uintptr(newFD), nil, nil +} + +// Dup2 implements Linux syscall dup2(2). +func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldfd := args[0].Int() + newfd := args[1].Int() + + if oldfd == newfd { + // As long as oldfd is valid, dup2() does nothing and returns newfd. + file := t.GetFileVFS2(oldfd) + if file == nil { + return 0, nil, syserror.EBADF + } + file.DecRef() + return uintptr(newfd), nil, nil + } + + return dup3(t, oldfd, newfd, 0) +} + +// Dup3 implements Linux syscall dup3(2). +func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldfd := args[0].Int() + newfd := args[1].Int() + flags := args[2].Uint() + + if oldfd == newfd { + return 0, nil, syserror.EINVAL + } + + return dup3(t, oldfd, newfd, flags) +} + +func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) { + if flags&^linux.O_CLOEXEC != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(oldfd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{ + CloseOnExec: flags&linux.O_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + return uintptr(newfd), nil, nil +} + +// Fcntl implements linux syscall fcntl(2). +func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + cmd := args[1].Int() + + file, flags := t.FDTable().GetVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + switch cmd { + case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: + minfd := args[2].Int() + fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{ + CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC, + }) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil + case linux.F_GETFD: + return uintptr(flags.ToLinuxFDFlags()), nil, nil + case linux.F_SETFD: + flags := args[2].Uint() + t.FDTable().SetFlags(fd, kernel.FDFlags{ + CloseOnExec: flags&linux.FD_CLOEXEC != 0, + }) + return 0, nil, nil + case linux.F_GETFL: + return uintptr(file.StatusFlags()), nil, nil + case linux.F_SETFL: + return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint()) + default: + // TODO(gvisor.dev/issue/1623): Everything else is not yet supported. + return 0, nil, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go new file mode 100644 index 000000000..fc5ceea4c --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -0,0 +1,326 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Link implements Linux syscall link(2). +func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldpathAddr := args[0].Pointer() + newpathAddr := args[1].Pointer() + return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */) +} + +// Linkat implements Linux syscall linkat(2). +func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + flags := args[4].Int() + return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) +} + +func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 { + return syserror.EINVAL + } + if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) { + return syserror.ENOENT + } + + oldpath, err := copyInPath(t, oldpathAddr) + if err != nil { + return err + } + oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0)) + if err != nil { + return err + } + defer oldtpop.Release() + + newpath, err := copyInPath(t, newpathAddr) + if err != nil { + return err + } + newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer newtpop.Release() + + return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop) +} + +// Mkdir implements Linux syscall mkdir(2). +func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode) +} + +// Mkdirat implements Linux syscall mkdirat(2). +func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + mode := args[2].ModeT() + return 0, nil, mkdirat(t, dirfd, addr, mode) +} + +func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error { + path, err := copyInPath(t, addr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{ + Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()), + }) +} + +// Mknod implements Linux syscall mknod(2). +func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + dev := args[2].Uint() + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) +} + +// Mknodat implements Linux syscall mknodat(2). +func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + mode := args[2].ModeT() + dev := args[3].Uint() + return 0, nil, mknodat(t, dirfd, addr, mode, dev) +} + +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { + path, err := copyInPath(t, addr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + major, minor := linux.DecodeDeviceID(dev) + return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ + Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + DevMajor: uint32(major), + DevMinor: minor, + }) +} + +// Open implements Linux syscall open(2). +func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + flags := args[1].Uint() + mode := args[2].ModeT() + return openat(t, linux.AT_FDCWD, addr, flags, mode) +} + +// Openat implements Linux syscall openat(2). +func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + addr := args[1].Pointer() + flags := args[2].Uint() + mode := args[3].ModeT() + return openat(t, dirfd, addr, flags, mode) +} + +// Creat implements Linux syscall creat(2). +func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + mode := args[1].ModeT() + return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode) +} + +func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) { + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0)) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{ + Flags: flags, + Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()), + }) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.O_CLOEXEC != 0, + }) + return uintptr(fd), nil, err +} + +// Rename implements Linux syscall rename(2). +func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + oldpathAddr := args[0].Pointer() + newpathAddr := args[1].Pointer() + return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */) +} + +// Renameat implements Linux syscall renameat(2). +func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */) +} + +// Renameat2 implements Linux syscall renameat2(2). +func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + olddirfd := args[0].Int() + oldpathAddr := args[1].Pointer() + newdirfd := args[2].Int() + newpathAddr := args[3].Pointer() + flags := args[4].Uint() + return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags) +} + +func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error { + oldpath, err := copyInPath(t, oldpathAddr) + if err != nil { + return err + } + // "If oldpath refers to a symbolic link, the link is renamed" - rename(2) + oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer oldtpop.Release() + + newpath, err := copyInPath(t, newpathAddr) + if err != nil { + return err + } + newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer newtpop.Release() + + return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{ + Flags: flags, + }) +} + +// Rmdir implements Linux syscall rmdir(2). +func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr) +} + +func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop) +} + +// Unlink implements Linux syscall unlink(2). +func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr) +} + +func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop) +} + +// Unlinkat implements Linux syscall unlinkat(2). +func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + flags := args[2].Int() + + if flags&^linux.AT_REMOVEDIR != 0 { + return 0, nil, syserror.EINVAL + } + + if flags&linux.AT_REMOVEDIR != 0 { + return 0, nil, rmdirat(t, dirfd, pathAddr) + } + return 0, nil, unlinkat(t, dirfd, pathAddr) +} + +// Symlink implements Linux syscall symlink(2). +func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + targetAddr := args[0].Pointer() + linkpathAddr := args[1].Pointer() + return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr) +} + +// Symlinkat implements Linux syscall symlinkat(2). +func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + targetAddr := args[0].Pointer() + newdirfd := args[1].Int() + linkpathAddr := args[2].Pointer() + return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr) +} + +func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error { + target, err := t.CopyInString(targetAddr, linux.PATH_MAX) + if err != nil { + return err + } + linkpath, err := copyInPath(t, linkpathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go new file mode 100644 index 000000000..317409a18 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Getcwd implements Linux syscall getcwd(2). +func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + size := args[1].SizeT() + + root := t.FSContext().RootDirectoryVFS2() + wd := t.FSContext().WorkingDirectoryVFS2() + s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd) + root.DecRef() + wd.DecRef() + if err != nil { + return 0, nil, err + } + + // Note this is >= because we need a terminator. + if uint(len(s)) >= size { + return 0, nil, syserror.ERANGE + } + + // Construct a byte slice containing a NUL terminator. + buf := t.CopyScratchBuffer(len(s) + 1) + copy(buf, s) + buf[len(buf)-1] = 0 + + // Write the pathname slice. + n, err := t.CopyOutBytes(addr, buf) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Chdir implements Linux syscall chdir(2). +func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetWorkingDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} + +// Fchdir implements Linux syscall fchdir(2). +func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetWorkingDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} + +// Chroot implements Linux syscall chroot(2). +func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + + if !t.HasCapability(linux.CAP_SYS_CHROOT) { + return 0, nil, syserror.EPERM + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + return 0, nil, err + } + t.FSContext().SetRootDirectoryVFS2(vd) + vd.DecRef() + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go new file mode 100644 index 000000000..ddc140b65 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Getdents implements Linux syscall getdents(2). +func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getdents(t, args, false /* isGetdents64 */) +} + +// Getdents64 implements Linux syscall getdents64(2). +func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getdents(t, args, true /* isGetdents64 */) +} + +func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := int(args[2].Uint()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + cb := getGetdentsCallback(t, addr, size, isGetdents64) + err := file.IterDirents(t, cb) + n := size - cb.remaining + putGetdentsCallback(cb) + if n == 0 { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +type getdentsCallback struct { + t *kernel.Task + addr usermem.Addr + remaining int + isGetdents64 bool +} + +var getdentsCallbackPool = sync.Pool{ + New: func() interface{} { + return &getdentsCallback{} + }, +} + +func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback { + cb := getdentsCallbackPool.Get().(*getdentsCallback) + *cb = getdentsCallback{ + t: t, + addr: addr, + remaining: size, + isGetdents64: isGetdents64, + } + return cb +} + +func putGetdentsCallback(cb *getdentsCallback) { + cb.t = nil + getdentsCallbackPool.Put(cb) +} + +// Handle implements vfs.IterDirentsCallback.Handle. +func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { + var buf []byte + if cb.isGetdents64 { + // struct linux_dirent64 { + // ino64_t d_ino; /* 64-bit inode number */ + // off64_t d_off; /* 64-bit offset to next structure */ + // unsigned short d_reclen; /* Size of this dirent */ + // unsigned char d_type; /* File type */ + // char d_name[]; /* Filename (null-terminated) */ + // }; + size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) + if size < cb.remaining { + return syserror.EINVAL + } + buf = cb.t.CopyScratchBuffer(size) + usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + buf[18] = dirent.Type + copy(buf[19:], dirent.Name) + buf[size-1] = 0 // NUL terminator + } else { + // struct linux_dirent { + // unsigned long d_ino; /* Inode number */ + // unsigned long d_off; /* Offset to next linux_dirent */ + // unsigned short d_reclen; /* Length of this linux_dirent */ + // char d_name[]; /* Filename (null-terminated) */ + // /* length is actually (d_reclen - 2 - + // offsetof(struct linux_dirent, d_name)) */ + // /* + // char pad; // Zero padding byte + // char d_type; // File type (only since Linux + // // 2.6.4); offset is (d_reclen - 1) + // */ + // }; + if cb.t.Arch().Width() != 8 { + panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width())) + } + size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name) + if size < cb.remaining { + return syserror.EINVAL + } + buf = cb.t.CopyScratchBuffer(size) + usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino) + usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) + usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) + copy(buf[18:], dirent.Name) + buf[size-3] = 0 // NUL terminator + buf[size-2] = 0 // zero padding byte + buf[size-1] = dirent.Type + } + n, err := cb.t.CopyOutBytes(cb.addr, buf) + if err != nil { + // Don't report partially-written dirents by advancing cb.addr or + // cb.remaining. + return err + } + cb.addr += usermem.Addr(n) + cb.remaining -= n + return nil +} diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 876783e78..5a2418da9 100644 --- a/pkg/usermem/usermem_unsafe.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package usermem +package vfs2 import ( - "unsafe" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" ) -// stringFromImmutableBytes is equivalent to string(bs), except that it never -// copies even if escape analysis can't prove that bs does not escape. This is -// only valid if bs is never mutated after stringFromImmutableBytes returns. -func stringFromImmutableBytes(bs []byte) string { - // Compare strings.Builder.String(). - return *(*string)(unsafe.Pointer(&bs)) +// Ioctl implements Linux syscall ioctl(2). +func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + ret, err := file.Ioctl(t, t.MemoryManager(), args) + return ret, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go index e0ac32b33..7d220bc20 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build amd64 + package vfs2 import ( @@ -22,110 +24,142 @@ import ( // Override syscall table to add syscalls implementations from this package. func Override(table map[uintptr]kernel.Syscall) { table[0] = syscalls.Supported("read", Read) - - // Remove syscalls that haven't been converted yet. It's better to get ENOSYS - // rather than a SIGSEGV deep in the stack. - delete(table, 1) // write - delete(table, 2) // open - delete(table, 3) // close - delete(table, 4) // stat - delete(table, 5) // fstat - delete(table, 6) // lstat - delete(table, 7) // poll - delete(table, 8) // lseek - delete(table, 9) // mmap - delete(table, 16) // ioctl - delete(table, 17) // pread64 - delete(table, 18) // pwrite64 - delete(table, 19) // readv - delete(table, 20) // writev - delete(table, 21) // access - delete(table, 22) // pipe - delete(table, 32) // dup - delete(table, 33) // dup2 - delete(table, 40) // sendfile - delete(table, 59) // execve - delete(table, 72) // fcntl - delete(table, 73) // flock - delete(table, 74) // fsync - delete(table, 75) // fdatasync - delete(table, 76) // truncate - delete(table, 77) // ftruncate - delete(table, 78) // getdents - delete(table, 79) // getcwd - delete(table, 80) // chdir - delete(table, 81) // fchdir - delete(table, 82) // rename - delete(table, 83) // mkdir - delete(table, 84) // rmdir - delete(table, 85) // creat - delete(table, 86) // link - delete(table, 87) // unlink - delete(table, 88) // symlink - delete(table, 89) // readlink - delete(table, 90) // chmod - delete(table, 91) // fchmod - delete(table, 92) // chown - delete(table, 93) // fchown - delete(table, 94) // lchown - delete(table, 133) // mknod - delete(table, 137) // statfs - delete(table, 138) // fstatfs - delete(table, 161) // chroot - delete(table, 162) // sync + table[1] = syscalls.Supported("write", Write) + table[2] = syscalls.Supported("open", Open) + table[3] = syscalls.Supported("close", Close) + table[4] = syscalls.Supported("stat", Stat) + table[5] = syscalls.Supported("fstat", Fstat) + table[6] = syscalls.Supported("lstat", Lstat) + table[7] = syscalls.Supported("poll", Poll) + table[8] = syscalls.Supported("lseek", Lseek) + table[9] = syscalls.Supported("mmap", Mmap) + table[16] = syscalls.Supported("ioctl", Ioctl) + table[17] = syscalls.Supported("pread64", Pread64) + table[18] = syscalls.Supported("pwrite64", Pwrite64) + table[19] = syscalls.Supported("readv", Readv) + table[20] = syscalls.Supported("writev", Writev) + table[21] = syscalls.Supported("access", Access) + delete(table, 22) // pipe + table[23] = syscalls.Supported("select", Select) + table[32] = syscalls.Supported("dup", Dup) + table[33] = syscalls.Supported("dup2", Dup2) + delete(table, 40) // sendfile + delete(table, 41) // socket + delete(table, 42) // connect + delete(table, 43) // accept + delete(table, 44) // sendto + delete(table, 45) // recvfrom + delete(table, 46) // sendmsg + delete(table, 47) // recvmsg + delete(table, 48) // shutdown + delete(table, 49) // bind + delete(table, 50) // listen + delete(table, 51) // getsockname + delete(table, 52) // getpeername + delete(table, 53) // socketpair + delete(table, 54) // setsockopt + delete(table, 55) // getsockopt + table[59] = syscalls.Supported("execve", Execve) + table[72] = syscalls.Supported("fcntl", Fcntl) + delete(table, 73) // flock + table[74] = syscalls.Supported("fsync", Fsync) + table[75] = syscalls.Supported("fdatasync", Fdatasync) + table[76] = syscalls.Supported("truncate", Truncate) + table[77] = syscalls.Supported("ftruncate", Ftruncate) + table[78] = syscalls.Supported("getdents", Getdents) + table[79] = syscalls.Supported("getcwd", Getcwd) + table[80] = syscalls.Supported("chdir", Chdir) + table[81] = syscalls.Supported("fchdir", Fchdir) + table[82] = syscalls.Supported("rename", Rename) + table[83] = syscalls.Supported("mkdir", Mkdir) + table[84] = syscalls.Supported("rmdir", Rmdir) + table[85] = syscalls.Supported("creat", Creat) + table[86] = syscalls.Supported("link", Link) + table[87] = syscalls.Supported("unlink", Unlink) + table[88] = syscalls.Supported("symlink", Symlink) + table[89] = syscalls.Supported("readlink", Readlink) + table[90] = syscalls.Supported("chmod", Chmod) + table[91] = syscalls.Supported("fchmod", Fchmod) + table[92] = syscalls.Supported("chown", Chown) + table[93] = syscalls.Supported("fchown", Fchown) + table[94] = syscalls.Supported("lchown", Lchown) + table[132] = syscalls.Supported("utime", Utime) + table[133] = syscalls.Supported("mknod", Mknod) + table[137] = syscalls.Supported("statfs", Statfs) + table[138] = syscalls.Supported("fstatfs", Fstatfs) + table[161] = syscalls.Supported("chroot", Chroot) + table[162] = syscalls.Supported("sync", Sync) delete(table, 165) // mount delete(table, 166) // umount2 - delete(table, 172) // iopl - delete(table, 173) // ioperm delete(table, 187) // readahead - delete(table, 188) // setxattr - delete(table, 189) // lsetxattr - delete(table, 190) // fsetxattr - delete(table, 191) // getxattr - delete(table, 192) // lgetxattr - delete(table, 193) // fgetxattr + table[188] = syscalls.Supported("setxattr", Setxattr) + table[189] = syscalls.Supported("lsetxattr", Lsetxattr) + table[190] = syscalls.Supported("fsetxattr", Fsetxattr) + table[191] = syscalls.Supported("getxattr", Getxattr) + table[192] = syscalls.Supported("lgetxattr", Lgetxattr) + table[193] = syscalls.Supported("fgetxattr", Fgetxattr) + table[194] = syscalls.Supported("listxattr", Listxattr) + table[195] = syscalls.Supported("llistxattr", Llistxattr) + table[196] = syscalls.Supported("flistxattr", Flistxattr) + table[197] = syscalls.Supported("removexattr", Removexattr) + table[198] = syscalls.Supported("lremovexattr", Lremovexattr) + table[199] = syscalls.Supported("fremovexattr", Fremovexattr) delete(table, 206) // io_setup delete(table, 207) // io_destroy delete(table, 208) // io_getevents delete(table, 209) // io_submit delete(table, 210) // io_cancel - delete(table, 213) // epoll_create - delete(table, 214) // epoll_ctl_old - delete(table, 215) // epoll_wait_old - delete(table, 216) // remap_file_pages - delete(table, 217) // getdents64 - delete(table, 232) // epoll_wait - delete(table, 233) // epoll_ctl + table[213] = syscalls.Supported("epoll_create", EpollCreate) + table[217] = syscalls.Supported("getdents64", Getdents64) + delete(table, 221) // fdavise64 + table[232] = syscalls.Supported("epoll_wait", EpollWait) + table[233] = syscalls.Supported("epoll_ctl", EpollCtl) + table[235] = syscalls.Supported("utimes", Utimes) delete(table, 253) // inotify_init delete(table, 254) // inotify_add_watch delete(table, 255) // inotify_rm_watch - delete(table, 257) // openat - delete(table, 258) // mkdirat - delete(table, 259) // mknodat - delete(table, 260) // fchownat - delete(table, 261) // futimesat - delete(table, 262) // fstatat - delete(table, 263) // unlinkat - delete(table, 264) // renameat - delete(table, 265) // linkat - delete(table, 266) // symlinkat - delete(table, 267) // readlinkat - delete(table, 268) // fchmodat - delete(table, 269) // faccessat - delete(table, 270) // pselect - delete(table, 271) // ppoll + table[257] = syscalls.Supported("openat", Openat) + table[258] = syscalls.Supported("mkdirat", Mkdirat) + table[259] = syscalls.Supported("mknodat", Mknodat) + table[260] = syscalls.Supported("fchownat", Fchownat) + table[261] = syscalls.Supported("futimens", Futimens) + table[262] = syscalls.Supported("newfstatat", Newfstatat) + table[263] = syscalls.Supported("unlinkat", Unlinkat) + table[264] = syscalls.Supported("renameat", Renameat) + table[265] = syscalls.Supported("linkat", Linkat) + table[266] = syscalls.Supported("symlinkat", Symlinkat) + table[267] = syscalls.Supported("readlinkat", Readlinkat) + table[268] = syscalls.Supported("fchmodat", Fchmodat) + table[269] = syscalls.Supported("faccessat", Faccessat) + table[270] = syscalls.Supported("pselect", Pselect) + table[271] = syscalls.Supported("ppoll", Ppoll) + delete(table, 275) // splice + delete(table, 276) // tee + table[277] = syscalls.Supported("sync_file_range", SyncFileRange) + table[280] = syscalls.Supported("utimensat", Utimensat) + table[281] = syscalls.Supported("epoll_pwait", EpollPwait) + delete(table, 282) // signalfd + delete(table, 283) // timerfd_create + delete(table, 284) // eventfd delete(table, 285) // fallocate - delete(table, 291) // epoll_create1 - delete(table, 292) // dup3 + delete(table, 286) // timerfd_settime + delete(table, 287) // timerfd_gettime + delete(table, 288) // accept4 + delete(table, 289) // signalfd4 + delete(table, 290) // eventfd2 + table[291] = syscalls.Supported("epoll_create1", EpollCreate1) + table[292] = syscalls.Supported("dup3", Dup3) delete(table, 293) // pipe2 delete(table, 294) // inotify_init1 - delete(table, 295) // preadv - delete(table, 296) // pwritev - delete(table, 306) // syncfs - delete(table, 316) // renameat2 + table[295] = syscalls.Supported("preadv", Preadv) + table[296] = syscalls.Supported("pwritev", Pwritev) + delete(table, 299) // recvmmsg + table[306] = syscalls.Supported("syncfs", Syncfs) + delete(table, 307) // sendmmsg + table[316] = syscalls.Supported("renameat2", Renameat2) delete(table, 319) // memfd_create - delete(table, 322) // execveat - delete(table, 327) // preadv2 - delete(table, 328) // pwritev2 - delete(table, 332) // statx + table[322] = syscalls.Supported("execveat", Execveat) + table[327] = syscalls.Supported("preadv2", Preadv2) + table[328] = syscalls.Supported("pwritev2", Pwritev2) + table[332] = syscalls.Supported("statx", Statx) } diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go index 6af5c400f..a6b367468 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package vfs2 import ( diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go new file mode 100644 index 000000000..60a43f0a0 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -0,0 +1,92 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Mmap implements Linux syscall mmap(2). +func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + prot := args[2].Int() + flags := args[3].Int() + fd := args[4].Int() + fixed := flags&linux.MAP_FIXED != 0 + private := flags&linux.MAP_PRIVATE != 0 + shared := flags&linux.MAP_SHARED != 0 + anon := flags&linux.MAP_ANONYMOUS != 0 + map32bit := flags&linux.MAP_32BIT != 0 + + // Require exactly one of MAP_PRIVATE and MAP_SHARED. + if private == shared { + return 0, nil, syserror.EINVAL + } + + opts := memmap.MMapOpts{ + Length: args[1].Uint64(), + Offset: args[5].Uint64(), + Addr: args[0].Pointer(), + Fixed: fixed, + Unmap: fixed, + Map32Bit: map32bit, + Private: private, + Perms: usermem.AccessType{ + Read: linux.PROT_READ&prot != 0, + Write: linux.PROT_WRITE&prot != 0, + Execute: linux.PROT_EXEC&prot != 0, + }, + MaxPerms: usermem.AnyAccess, + GrowsDown: linux.MAP_GROWSDOWN&flags != 0, + Precommit: linux.MAP_POPULATE&flags != 0, + } + if linux.MAP_LOCKED&flags != 0 { + opts.MLockMode = memmap.MLockEager + } + defer func() { + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef() + } + }() + + if !anon { + // Convert the passed FD to a file reference. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // mmap unconditionally requires that the FD is readable. + if !file.IsReadable() { + return 0, nil, syserror.EACCES + } + // MAP_SHARED requires that the FD be writable for PROT_WRITE. + if shared && !file.IsWritable() { + opts.MaxPerms.Write = false + } + + if err := file.ConfigureMMap(t, &opts); err != nil { + return 0, nil, err + } + } + + rv, err := t.MemoryManager().MMap(t, opts) + return uintptr(rv), nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go new file mode 100644 index 000000000..97da6c647 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) { + pathname, err := t.CopyInString(addr, linux.PATH_MAX) + if err != nil { + return fspath.Path{}, err + } + return fspath.Parse(pathname), nil +} + +type taskPathOperation struct { + pop vfs.PathOperation + haveStartRef bool +} + +func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) { + root := t.FSContext().RootDirectoryVFS2() + start := root + haveStartRef := false + if !path.Absolute { + if !path.HasComponents() && !bool(shouldAllowEmptyPath) { + root.DecRef() + return taskPathOperation{}, syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + haveStartRef = true + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + root.DecRef() + return taskPathOperation{}, syserror.EBADF + } + start = dirfile.VirtualDentry() + start.IncRef() + haveStartRef = true + dirfile.DecRef() + } + } + return taskPathOperation{ + pop: vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: bool(shouldFollowFinalSymlink), + }, + haveStartRef: haveStartRef, + }, nil +} + +func (tpop *taskPathOperation) Release() { + tpop.pop.Root.DecRef() + if tpop.haveStartRef { + tpop.pop.Start.DecRef() + tpop.haveStartRef = false + } +} + +type shouldAllowEmptyPath bool + +const ( + disallowEmptyPath shouldAllowEmptyPath = false + allowEmptyPath shouldAllowEmptyPath = true +) + +type shouldFollowFinalSymlink bool + +const ( + nofollowFinalSymlink shouldFollowFinalSymlink = false + followFinalSymlink shouldFollowFinalSymlink = true +) diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go new file mode 100644 index 000000000..dbf4882da --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -0,0 +1,584 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// fileCap is the maximum allowable files for poll & select. This has no +// equivalent in Linux; it exists in gVisor since allocation failure in Go is +// unrecoverable. +const fileCap = 1024 * 1024 + +// Masks for "readable", "writable", and "exceptional" events as defined by +// select(2). +const ( + // selectReadEvents is analogous to the Linux kernel's + // fs/select.c:POLLIN_SET. + selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR + + // selectWriteEvents is analogous to the Linux kernel's + // fs/select.c:POLLOUT_SET. + selectWriteEvents = linux.POLLOUT | linux.POLLERR + + // selectExceptEvents is analogous to the Linux kernel's + // fs/select.c:POLLEX_SET. + selectExceptEvents = linux.POLLPRI +) + +// pollState tracks the associated file description and waiter of a PollFD. +type pollState struct { + file *vfs.FileDescription + waiter waiter.Entry +} + +// initReadiness gets the current ready mask for the file represented by the FD +// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is +// used to register with the file for event notifications, and a reference to +// the file is stored in "state". +func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) { + if pfd.FD < 0 { + pfd.REvents = 0 + return + } + + file := t.GetFileVFS2(pfd.FD) + if file == nil { + pfd.REvents = linux.POLLNVAL + return + } + + if ch == nil { + defer file.DecRef() + } else { + state.file = file + state.waiter, _ = waiter.NewChannelEntry(ch) + file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events))) + } + + r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events))) + pfd.REvents = int16(r.ToLinux()) & pfd.Events +} + +// releaseState releases all the pollState in "state". +func releaseState(state []pollState) { + for i := range state { + if state[i].file != nil { + state[i].file.EventUnregister(&state[i].waiter) + state[i].file.DecRef() + } + } +} + +// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout" +// when "timeout" is greater than zero. +// +// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or +// positive if interrupted by a signal. +func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) { + var ch chan struct{} + if timeout != 0 { + ch = make(chan struct{}, 1) + } + + // Register for event notification in the files involved if we may + // block (timeout not zero). Once we find a file that has a non-zero + // result, we stop registering for events but still go through all files + // to get their ready masks. + state := make([]pollState, len(pfd)) + defer releaseState(state) + n := uintptr(0) + for i := range pfd { + initReadiness(t, &pfd[i], &state[i], ch) + if pfd[i].REvents != 0 { + n++ + ch = nil + } + } + + if timeout == 0 { + return timeout, n, nil + } + + haveTimeout := timeout >= 0 + + for n == 0 { + var err error + // Wait for a notification. + timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout) + if err != nil { + if err == syserror.ETIMEDOUT { + err = nil + } + return timeout, 0, err + } + + // We got notified, count how many files are ready. If none, + // then this was a spurious notification, and we just go back + // to sleep with the remaining timeout. + for i := range state { + if state[i].file == nil { + continue + } + + r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events))) + rl := int16(r.ToLinux()) & pfd[i].Events + if rl != 0 { + pfd[i].REvents = rl + n++ + } + } + } + + return timeout, n, nil +} + +// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. +func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) { + if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { + return nil, syserror.EINVAL + } + + pfd := make([]linux.PollFD, nfds) + if nfds > 0 { + if _, err := t.CopyIn(addr, &pfd); err != nil { + return nil, err + } + } + + return pfd, nil +} + +func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) { + pfd, err := copyInPollFDs(t, addr, nfds) + if err != nil { + return timeout, 0, err + } + + // Compatibility warning: Linux adds POLLHUP and POLLERR just before + // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after + // polling, changing event masks here is an application-visible difference. + // (Linux also doesn't copy out event masks at all, only revents.) + for i := range pfd { + pfd[i].Events |= linux.POLLHUP | linux.POLLERR + } + remainingTimeout, n, err := pollBlock(t, pfd, timeout) + err = syserror.ConvertIntr(err, syserror.EINTR) + + // The poll entries are copied out regardless of whether + // any are set or not. This aligns with the Linux behavior. + if nfds > 0 && err == nil { + if _, err := t.CopyOut(addr, pfd); err != nil { + return remainingTimeout, 0, err + } + } + + return remainingTimeout, n, err +} + +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + +func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { + if nfds < 0 || nfds > fileCap { + return 0, syserror.EINVAL + } + + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 + + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err + } + + // Count how many FDs are actually being requested so that we can build + // a PollFD array. + fdCount := 0 + for i := 0; i < nBytes; i++ { + v := r[i] | w[i] | e[i] + for v != 0 { + v &= (v - 1) + fdCount++ + } + } + + // Build the PollFD array. + pfd := make([]linux.PollFD, 0, fdCount) + var fd int32 + for i := 0; i < nBytes; i++ { + rV, wV, eV := r[i], w[i], e[i] + v := rV | wV | eV + m := byte(1) + for j := 0; j < 8; j++ { + if (v & m) != 0 { + // Make sure the fd is valid and decrement the reference + // immediately to ensure we don't leak. Note, another thread + // might be about to close fd. This is racy, but that's + // OK. Linux is racy in the same way. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + file.DecRef() + + var mask int16 + if (rV & m) != 0 { + mask |= selectReadEvents + } + + if (wV & m) != 0 { + mask |= selectWriteEvents + } + + if (eV & m) != 0 { + mask |= selectExceptEvents + } + + pfd = append(pfd, linux.PollFD{ + FD: fd, + Events: mask, + }) + } + + fd++ + m <<= 1 + } + } + + // Do the syscall, then count the number of bits set. + if _, _, err = pollBlock(t, pfd, timeout); err != nil { + return 0, syserror.ConvertIntr(err, syserror.EINTR) + } + + // r, w, and e are currently event mask bitsets; unset bits corresponding + // to events that *didn't* occur. + bitSetCount := uintptr(0) + for idx := range pfd { + events := pfd[idx].REvents + i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8) + m := byte(1) << j + if r[i]&m != 0 { + if (events & selectReadEvents) != 0 { + bitSetCount++ + } else { + r[i] &^= m + } + } + if w[i]&m != 0 { + if (events & selectWriteEvents) != 0 { + bitSetCount++ + } else { + w[i] &^= m + } + } + if e[i]&m != 0 { + if (events & selectExceptEvents) != 0 { + bitSetCount++ + } else { + e[i] &^= m + } + } + } + + // Copy updated vectors back. + if readFDs != 0 { + if _, err := t.CopyOut(readFDs, r); err != nil { + return 0, err + } + } + + if writeFDs != 0 { + if _, err := t.CopyOut(writeFDs, w); err != nil { + return 0, err + } + } + + if exceptFDs != 0 { + if _, err := t.CopyOut(exceptFDs, e); err != nil { + return 0, err + } + } + + return bitSetCount, nil +} + +// timeoutRemaining returns the amount of time remaining for the specified +// timeout or 0 if it has elapsed. +// +// startNs must be from CLOCK_MONOTONIC. +func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration { + now := t.Kernel().MonotonicClock().Now() + remaining := timeout - now.Sub(startNs) + if remaining < 0 { + remaining = 0 + } + return remaining +} + +// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr. +// +// startNs must be from CLOCK_MONOTONIC. +func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error { + if timeout <= 0 { + return nil + } + remaining := timeoutRemaining(t, startNs, timeout) + tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds()) + return tsRemaining.CopyOut(t, timespecAddr) +} + +// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr. +// +// startNs must be from CLOCK_MONOTONIC. +func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error { + if timeout <= 0 { + return nil + } + remaining := timeoutRemaining(t, startNs, timeout) + tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds()) + return tvRemaining.CopyOut(t, timevalAddr) +} + +// pollRestartBlock encapsulates the state required to restart poll(2) via +// restart_syscall(2). +// +// +stateify savable +type pollRestartBlock struct { + pfdAddr usermem.Addr + nfds uint + timeout time.Duration +} + +// Restart implements kernel.SyscallRestartBlock.Restart. +func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { + return poll(t, p.pfdAddr, p.nfds, p.timeout) +} + +func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) { + remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) + // On an interrupt poll(2) is restarted with the remaining timeout. + if err == syserror.EINTR { + t.SetSyscallRestartBlock(&pollRestartBlock{ + pfdAddr: pfdAddr, + nfds: nfds, + timeout: remainingTimeout, + }) + return 0, kernel.ERESTART_RESTARTBLOCK + } + return n, err +} + +// Poll implements linux syscall poll(2). +func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pfdAddr := args[0].Pointer() + nfds := uint(args[1].Uint()) // poll(2) uses unsigned long. + timeout := time.Duration(args[2].Int()) * time.Millisecond + n, err := poll(t, pfdAddr, nfds, timeout) + return n, nil, err +} + +// Ppoll implements linux syscall ppoll(2). +func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pfdAddr := args[0].Pointer() + nfds := uint(args[1].Uint()) // poll(2) uses unsigned long. + timespecAddr := args[2].Pointer() + maskAddr := args[3].Pointer() + maskSize := uint(args[4].Uint()) + + timeout, err := copyTimespecInToDuration(t, timespecAddr) + if err != nil { + return 0, nil, err + } + + var startNs ktime.Time + if timeout > 0 { + startNs = t.Kernel().MonotonicClock().Now() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + _, n, err := doPoll(t, pfdAddr, nfds, timeout) + copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) + // doPoll returns EINTR if interrupted, but ppoll is normally restartable + // if interrupted by something other than a signal handled by the + // application (i.e. returns ERESTARTNOHAND). However, if + // copyOutTimespecRemaining failed, then the restarted ppoll would use the + // wrong timeout, so the error should be left as EINTR. + // + // Note that this means that if err is nil but copyErr is not, copyErr is + // ignored. This is consistent with Linux. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// Select implements linux syscall select(2). +func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + nfds := int(args[0].Int()) // select(2) uses an int. + readFDs := args[1].Pointer() + writeFDs := args[2].Pointer() + exceptFDs := args[3].Pointer() + timevalAddr := args[4].Pointer() + + // Use a negative Duration to indicate "no timeout". + timeout := time.Duration(-1) + if timevalAddr != 0 { + var timeval linux.Timeval + if err := timeval.CopyIn(t, timevalAddr); err != nil { + return 0, nil, err + } + if timeval.Sec < 0 || timeval.Usec < 0 { + return 0, nil, syserror.EINVAL + } + timeout = time.Duration(timeval.ToNsecCapped()) + } + startNs := t.Kernel().MonotonicClock().Now() + n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) + copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) + // See comment in Ppoll. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// Pselect implements linux syscall pselect(2). +func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + nfds := int(args[0].Int()) // select(2) uses an int. + readFDs := args[1].Pointer() + writeFDs := args[2].Pointer() + exceptFDs := args[3].Pointer() + timespecAddr := args[4].Pointer() + maskWithSizeAddr := args[5].Pointer() + + timeout, err := copyTimespecInToDuration(t, timespecAddr) + if err != nil { + return 0, nil, err + } + + var startNs ktime.Time + if timeout > 0 { + startNs = t.Kernel().MonotonicClock().Now() + } + + if maskWithSizeAddr != 0 { + if t.Arch().Width() != 8 { + panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width())) + } + var maskStruct sigSetWithSize + if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil { + return 0, nil, err + } + if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil { + return 0, nil, err + } + } + + n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) + copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) + // See comment in Ppoll. + if err == syserror.EINTR && copyErr == nil { + err = kernel.ERESTARTNOHAND + } + return n, nil, err +} + +// +marshal +type sigSetWithSize struct { + sigsetAddr uint64 + sizeofSigset uint64 +} + +// copyTimespecInToDuration copies a Timespec from the untrusted app range, +// validates it and converts it to a Duration. +// +// If the Timespec is larger than what can be represented in a Duration, the +// returned value is the maximum that Duration will allow. +// +// If timespecAddr is NULL, the returned value is negative. +func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) { + // Use a negative Duration to indicate "no timeout". + timeout := time.Duration(-1) + if timespecAddr != 0 { + var timespec linux.Timespec + if err := timespec.CopyIn(t, timespecAddr); err != nil { + return 0, err + } + if !timespec.Valid() { + return 0, syserror.EINVAL + } + timeout = time.Duration(timespec.ToNsecCapped()) + } + return timeout, nil +} + +func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error { + if maskAddr == 0 { + return nil + } + if maskSize != linux.SignalSetSize { + return syserror.EINVAL + } + var mask linux.SignalSet + if err := mask.CopyIn(t, maskAddr); err != nil { + return err + } + mask &^= kernel.UnblockableSignals + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + return nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go new file mode 100644 index 000000000..35f6308d6 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -0,0 +1,511 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr + eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr +) + +// Read implements Linux syscall read(2). +func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := read(t, file, dst, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) +} + +// Readv implements Linux syscall readv(2). +func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := read(t, file, dst, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) +} + +func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + n, err := file.Read(t, dst, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskRead) + + total := n + for { + // Shorten dst to reflect bytes previously read. + dst = dst.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.Read(t, dst, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Pread64 implements Linux syscall pread64(2). +func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) +} + +// Preadv implements Linux syscall preadv(2). +func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) +} + +// Preadv2 implements Linux syscall preadv2(2). +func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // While the glibc signature is + // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) + // the actual syscall + // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142) + // splits the offset argument into a high/low value for compatibility with + // 32-bit architectures. The flags argument is the 6th argument (index 5). + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + flags := args[5].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < -1 { + return 0, nil, syserror.EINVAL + } + + // Get the destination of the read. + dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + opts := vfs.ReadOptions{ + Flags: uint32(flags), + } + var n int64 + if offset == -1 { + n, err = read(t, file, dst, opts) + } else { + n, err = pread(t, file, dst, offset, opts) + } + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) +} + +func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + n, err := file.PRead(t, dst, offset, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskRead) + + total := n + for { + // Shorten dst to reflect bytes previously read. + dst = dst.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.PRead(t, dst, offset+total, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Write implements Linux syscall write(2). +func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := write(t, file, src, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file) +} + +// Writev implements Linux syscall writev(2). +func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := write(t, file, src, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) +} + +func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + n, err := file.Write(t, src, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskWrite) + + total := n + for { + // Shorten src to reflect bytes previously written. + src = src.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.Write(t, src, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Pwrite64 implements Linux syscall pwrite64(2). +func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + size := args[2].SizeT() + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the size is legitimate. + si := int(size) + if si < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) +} + +// Pwritev implements Linux syscall pwritev(2). +func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) + t.IOUsage().AccountReadSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) +} + +// Pwritev2 implements Linux syscall pwritev2(2). +func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // While the glibc signature is + // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) + // the actual syscall + // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162) + // splits the offset argument into a high/low value for compatibility with + // 32-bit architectures. The flags argument is the 6th argument (index 5). + fd := args[0].Int() + addr := args[1].Pointer() + iovcnt := int(args[2].Int()) + offset := args[3].Int64() + flags := args[5].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the offset is legitimate. + if offset < -1 { + return 0, nil, syserror.EINVAL + } + + // Get the source of the write. + src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, nil, err + } + + opts := vfs.WriteOptions{ + Flags: uint32(flags), + } + var n int64 + if offset == -1 { + n, err = write(t, file, src, opts) + } else { + n, err = pwrite(t, file, src, offset, opts) + } + t.IOUsage().AccountWriteSyscall(n) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) +} + +func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, err := file.PWrite(t, src, offset, opts) + if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 { + return n, err + } + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + file.EventRegister(&w, eventMaskWrite) + + total := n + for { + // Shorten src to reflect bytes previously written. + src = src.DropFirst(int(n)) + + // Issue the request and break out if it completes with anything other than + // "would block". + n, err := file.PWrite(t, src, offset+total, opts) + total += n + if err != syserror.ErrWouldBlock { + break + } + if err := t.Block(ch); err != nil { + break + } + } + file.EventUnregister(&w) + + return total, err +} + +// Lseek implements Linux syscall lseek(2). +func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + whence := args[2].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + newoff, err := file.Seek(t, offset, whence) + return uintptr(newoff), nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go new file mode 100644 index 000000000..9250659ff --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -0,0 +1,380 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + +// Chmod implements Linux syscall chmod(2). +func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + mode := args[1].ModeT() + return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode) +} + +// Fchmodat implements Linux syscall fchmodat(2). +func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + mode := args[2].ModeT() + return 0, nil, fchmodat(t, dirfd, pathAddr, mode) +} + +func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error { + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_MODE, + Mode: uint16(mode & chmodMask), + }, + }) +} + +// Fchmod implements Linux syscall fchmod(2). +func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + mode := args[1].ModeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SetStat(t, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_MODE, + Mode: uint16(mode & chmodMask), + }, + }) +} + +// Chown implements Linux syscall chown(2). +func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + owner := args[1].Int() + group := args[2].Int() + return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */) +} + +// Lchown implements Linux syscall lchown(2). +func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + owner := args[1].Int() + group := args[2].Int() + return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW) +} + +// Fchownat implements Linux syscall fchownat(2). +func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + owner := args[2].Int() + group := args[3].Int() + flags := args[4].Int() + return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags) +} + +func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { + return err + } + + return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts) +} + +func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error { + userns := t.UserNamespace() + if owner != -1 { + kuid := userns.MapToKUID(auth.UID(owner)) + if !kuid.Ok() { + return syserror.EINVAL + } + opts.Stat.Mask |= linux.STATX_UID + opts.Stat.UID = uint32(kuid) + } + if group != -1 { + kgid := userns.MapToKGID(auth.GID(group)) + if !kgid.Ok() { + return syserror.EINVAL + } + opts.Stat.Mask |= linux.STATX_GID + opts.Stat.GID = uint32(kgid) + } + return nil +} + +// Fchown implements Linux syscall fchown(2). +func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + owner := args[1].Int() + group := args[2].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { + return 0, nil, err + } + return 0, nil, file.SetStat(t, opts) +} + +// Truncate implements Linux syscall truncate(2). +func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + length := args[1].Int64() + + if length < 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_SIZE, + Size: uint64(length), + }, + }) +} + +// Ftruncate implements Linux syscall ftruncate(2). +func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + length := args[1].Int64() + + if length < 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SetStat(t, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_SIZE, + Size: uint64(length), + }, + }) +} + +// Utime implements Linux syscall utime(2). +func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + timesAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + opts := vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_ATIME | linux.STATX_MTIME, + }, + } + if timesAddr == 0 { + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + } else { + var times linux.Utime + if err := times.CopyIn(t, timesAddr); err != nil { + return 0, nil, err + } + opts.Stat.Atime.Sec = times.Actime + opts.Stat.Mtime.Sec = times.Modtime + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Utimes implements Linux syscall utimes(2). +func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + timesAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + opts := vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: linux.STATX_ATIME | linux.STATX_MTIME, + }, + } + if timesAddr == 0 { + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + } else { + var times [2]linux.Timeval + if _, err := t.CopyIn(timesAddr, ×); err != nil { + return 0, nil, err + } + opts.Stat.Atime = linux.StatxTimestamp{ + Sec: times[0].Sec, + Nsec: uint32(times[0].Usec * 1000), + } + opts.Stat.Mtime = linux.StatxTimestamp{ + Sec: times[1].Sec, + Nsec: uint32(times[1].Usec * 1000), + } + } + + return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Utimensat implements Linux syscall utimensat(2). +func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + timesAddr := args[2].Pointer() + flags := args[3].Int() + + if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil { + return 0, nil, err + } + + return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts) +} + +// Futimens implements Linux syscall futimens(2). +func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + timesAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + var opts vfs.SetStatOptions + if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil { + return 0, nil, err + } + + return 0, nil, file.SetStat(t, opts) +} + +func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error { + if timesAddr == 0 { + opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME + opts.Stat.Atime.Nsec = linux.UTIME_NOW + opts.Stat.Mtime.Nsec = linux.UTIME_NOW + return nil + } + var times [2]linux.Timespec + if _, err := t.CopyIn(timesAddr, ×); err != nil { + return err + } + if times[0].Nsec != linux.UTIME_OMIT { + opts.Stat.Mask |= linux.STATX_ATIME + opts.Stat.Atime = linux.StatxTimestamp{ + Sec: times[0].Sec, + Nsec: uint32(times[0].Nsec), + } + } + if times[1].Nsec != linux.UTIME_OMIT { + opts.Stat.Mask |= linux.STATX_MTIME + opts.Stat.Mtime = linux.StatxTimestamp{ + Sec: times[1].Sec, + Nsec: uint32(times[1].Nsec), + } + } + return nil +} + +func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error { + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && !bool(shouldAllowEmptyPath) { + return syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.SetStat() instead of + // VirtualFilesystem.SetStatAt(), since the former may be able + // to use opened file state to expedite the SetStat. + err := dirfile.SetStat(t, *opts) + dirfile.DecRef() + return err + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: bool(shouldFollowFinalSymlink), + }, opts) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go new file mode 100644 index 000000000..dca8d7011 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -0,0 +1,346 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Stat implements Linux syscall stat(2). +func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + statAddr := args[1].Pointer() + return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */) +} + +// Lstat implements Linux syscall lstat(2). +func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + statAddr := args[1].Pointer() + return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW) +} + +// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2). +func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + statAddr := args[2].Pointer() + flags := args[3].Int() + return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags) +} + +func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return syserror.EINVAL + } + + opts := vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS, + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.Stat() instead of + // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the + // former may be able to use opened file state to expedite the + // Stat. + statx, err := dirfile.Stat(t, opts) + dirfile.DecRef() + if err != nil { + return err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return stat.CopyOut(t, statAddr) + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + + statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &opts) + if err != nil { + return err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return stat.CopyOut(t, statAddr) +} + +// This takes both input and output as pointer arguments to avoid copying large +// structs. +func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) { + // Linux just copies fields from struct kstat without regard to struct + // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too. + userns := t.UserNamespace() + *stat = linux.Stat{ + Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)), + Ino: statx.Ino, + Nlink: uint64(statx.Nlink), + Mode: uint32(statx.Mode), + UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()), + GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)), + Size: int64(statx.Size), + Blksize: int64(statx.Blksize), + Blocks: int64(statx.Blocks), + ATime: timespecFromStatxTimestamp(statx.Atime), + MTime: timespecFromStatxTimestamp(statx.Mtime), + CTime: timespecFromStatxTimestamp(statx.Ctime), + } +} + +func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec { + return linux.Timespec{ + Sec: sxts.Sec, + Nsec: int64(sxts.Nsec), + } +} + +// Fstat implements Linux syscall fstat(2). +func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + statAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + statx, err := file.Stat(t, vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS, + }) + if err != nil { + return 0, nil, err + } + var stat linux.Stat + convertStatxToUserStat(t, &statx, &stat) + return 0, nil, stat.CopyOut(t, statAddr) +} + +// Statx implements Linux syscall statx(2). +func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + flags := args[2].Int() + mask := args[3].Uint() + statxAddr := args[4].Pointer() + + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + + opts := vfs.StatOptions{ + Mask: mask, + Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE), + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + if !path.Absolute { + if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { + return 0, nil, syserror.ENOENT + } + if dirfd == linux.AT_FDCWD { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } else { + dirfile := t.GetFileVFS2(dirfd) + if dirfile == nil { + return 0, nil, syserror.EBADF + } + if !path.HasComponents() { + // Use FileDescription.Stat() instead of + // VirtualFilesystem.StatAt() for statx(fd, ""), since the + // former may be able to use opened file state to expedite the + // Stat. + statx, err := dirfile.Stat(t, opts) + dirfile.DecRef() + if err != nil { + return 0, nil, err + } + userifyStatx(t, &statx) + return 0, nil, statx.CopyOut(t, statxAddr) + } + start = dirfile.VirtualDentry() + start.IncRef() + defer start.DecRef() + dirfile.DecRef() + } + } + + statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0, + }, &opts) + if err != nil { + return 0, nil, err + } + userifyStatx(t, &statx) + return 0, nil, statx.CopyOut(t, statxAddr) +} + +func userifyStatx(t *kernel.Task, statx *linux.Statx) { + userns := t.UserNamespace() + statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow()) + statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow()) +} + +// Readlink implements Linux syscall readlink(2). +func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + bufAddr := args[1].Pointer() + size := args[2].SizeT() + return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size) +} + +// Access implements Linux syscall access(2). +func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // FIXME(jamieliu): actually implement + return 0, nil, nil +} + +// Faccessat implements Linux syscall access(2). +func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // FIXME(jamieliu): actually implement + return 0, nil, nil +} + +// Readlinkat implements Linux syscall mknodat(2). +func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirfd := args[0].Int() + pathAddr := args[1].Pointer() + bufAddr := args[2].Pointer() + size := args[3].SizeT() + return readlinkat(t, dirfd, pathAddr, bufAddr, size) +} + +func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) { + if int(size) <= 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + // "Since Linux 2.6.39, pathname can be an empty string, in which case the + // call operates on the symbolic link referred to by dirfd ..." - + // readlinkat(2) + tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + if len(target) > int(size) { + target = target[:size] + } + n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target)) + if n == 0 { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Statfs implements Linux syscall statfs(2). +func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + bufAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + return 0, nil, statfs.CopyOut(t, bufAddr) +} + +// Fstatfs implements Linux syscall fstatfs(2). +func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufAddr := args[1].Pointer() + + tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + + return 0, nil, statfs.CopyOut(t, bufAddr) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go new file mode 100644 index 000000000..365250b0b --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Sync implements Linux syscall sync(2). +func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t) +} + +// Syncfs implements Linux syscall syncfs(2). +func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.SyncFS(t) +} + +// Fsync implements Linux syscall fsync(2). +func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + return 0, nil, file.Sync(t) +} + +// Fdatasync implements Linux syscall fdatasync(2). +func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata. + return Fsync(t, args) +} + +// SyncFileRange implements Linux syscall sync_file_range(2). +func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + nbytes := args[2].Int64() + flags := args[3].Uint() + + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if nbytes < 0 { + return 0, nil, syserror.EINVAL + } + if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of + // [offset, offset+nbytes). + return 0, nil, file.Sync(t) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go deleted file mode 100644 index 7667524c7..000000000 --- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs2 - -import ( - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // EventMaskRead contains events that can be triggered on reads. - EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr -) - -// Read implements linux syscall read(2). Note that we try to get a buffer that -// is exactly the size requested because some applications like qemu expect -// they can do large reads all at once. Bug for bug. Same for other read -// calls below. -func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - fd := args[0].Int() - addr := args[1].Pointer() - size := args[2].SizeT() - - file := t.GetFileVFS2(fd) - if file == nil { - return 0, nil, syserror.EBADF - } - defer file.DecRef() - - // Check that the size is legitimate. - si := int(size) - if si < 0 { - return 0, nil, syserror.EINVAL - } - - // Get the destination of the read. - dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{ - AddressSpaceActive: true, - }) - if err != nil { - return 0, nil, err - } - - n, err := read(t, file, dst, vfs.ReadOptions{}) - t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) -} - -func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - n, err := file.Read(t, dst, opts) - if err != syserror.ErrWouldBlock { - return n, err - } - - // Register for notifications. - w, ch := waiter.NewChannelEntry(nil) - file.EventRegister(&w, EventMaskRead) - - total := n - for { - // Shorten dst to reflect bytes previously read. - dst = dst.DropFirst(int(n)) - - // Issue the request and break out if it completes with anything other than - // "would block". - n, err := file.Read(t, dst, opts) - total += n - if err != syserror.ErrWouldBlock { - break - } - if err := t.Block(ch); err != nil { - break - } - } - file.EventUnregister(&w) - - return total, err -} diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go new file mode 100644 index 000000000..89e9ff4d7 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -0,0 +1,353 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "bytes" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Listxattr implements Linux syscall listxattr(2). +func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return listxattr(t, args, followFinalSymlink) +} + +// Llistxattr implements Linux syscall llistxattr(2). +func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return listxattr(t, args, nofollowFinalSymlink) +} + +func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + listAddr := args[1].Pointer() + size := args[2].SizeT() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrNameList(t, listAddr, size, names) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Flistxattr implements Linux syscall flistxattr(2). +func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + listAddr := args[1].Pointer() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + names, err := file.Listxattr(t) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrNameList(t, listAddr, size, names) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Getxattr implements Linux syscall getxattr(2). +func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getxattr(t, args, followFinalSymlink) +} + +// Lgetxattr implements Linux syscall lgetxattr(2). +func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return getxattr(t, args, nofollowFinalSymlink) +} + +func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrValue(t, valueAddr, size, value) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Fgetxattr implements Linux syscall fgetxattr(2). +func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + value, err := file.Getxattr(t, name) + if err != nil { + return 0, nil, err + } + n, err := copyOutXattrValue(t, valueAddr, size, value) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Setxattr implements Linux syscall setxattr(2). +func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, setxattr(t, args, followFinalSymlink) +} + +// Lsetxattr implements Linux syscall lsetxattr(2). +func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, setxattr(t, args, nofollowFinalSymlink) +} + +func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Int() + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + value, err := copyInXattrValue(t, valueAddr, size) + if err != nil { + return err + } + + return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{ + Name: name, + Value: value, + Flags: uint32(flags), + }) +} + +// Fsetxattr implements Linux syscall fsetxattr(2). +func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Int() + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + value, err := copyInXattrValue(t, valueAddr, size) + if err != nil { + return 0, nil, err + } + + return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{ + Name: name, + Value: value, + Flags: uint32(flags), + }) +} + +// Removexattr implements Linux syscall removexattr(2). +func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, removexattr(t, args, followFinalSymlink) +} + +// Lremovexattr implements Linux syscall lremovexattr(2). +func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, removexattr(t, args, nofollowFinalSymlink) +} + +func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink) + if err != nil { + return err + } + defer tpop.Release() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + + return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name) +} + +// Fremovexattr implements Linux syscall fremovexattr(2). +func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + nameAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return 0, nil, err + } + + return 0, nil, file.Removexattr(t, name) +} + +func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { + name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) + if err != nil { + if err == syserror.ENAMETOOLONG { + return "", syserror.ERANGE + } + return "", err + } + if len(name) == 0 { + return "", syserror.ERANGE + } + return name, nil +} + +func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) { + if size > linux.XATTR_LIST_MAX { + size = linux.XATTR_LIST_MAX + } + var buf bytes.Buffer + for _, name := range names { + buf.WriteString(name) + buf.WriteByte(0) + } + if size == 0 { + // Return the size that would be required to accomodate the list. + return buf.Len(), nil + } + if buf.Len() > int(size) { + if size >= linux.XATTR_LIST_MAX { + return 0, syserror.E2BIG + } + return 0, syserror.ERANGE + } + return t.CopyOutBytes(listAddr, buf.Bytes()) +} + +func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) { + if size > linux.XATTR_SIZE_MAX { + return "", syserror.E2BIG + } + buf := make([]byte, size) + if _, err := t.CopyInBytes(valueAddr, buf); err != nil { + return "", err + } + return gohacks.StringFromImmutableBytes(buf), nil +} + +func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) { + if size > linux.XATTR_SIZE_MAX { + size = linux.XATTR_SIZE_MAX + } + if size == 0 { + // Return the size that would be required to accomodate the value. + return len(value), nil + } + if len(value) > int(size) { + if size >= linux.XATTR_SIZE_MAX { + return 0, syserror.E2BIG + } + return 0, syserror.ERANGE + } + return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value)) +} diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0b4f18ab5..07c8383e6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/gohacks", "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index eed41139b..3da45d744 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin // Add epi to file.epolls so that it is removed when the last // FileDescription reference is dropped. file.epollMu.Lock() + if file.epolls == nil { + file.epolls = make(map[*epollInterest]struct{}) + } file.epolls[epi] = struct{}{} file.epollMu.Unlock() diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 9912df799..31a4e5480 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -139,6 +139,23 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } +// NewDisconnectedMount returns a Mount representing fs with the given root +// (which may be nil). The new Mount is not associated with any MountNamespace +// and is not connected to any other Mounts. References are taken on fs and +// root. +func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) { + fs.IncRef() + if root != nil { + root.IncRef() + } + return &Mount{ + vfs: vfs, + fs: fs, + root: root, + refs: 1, + }, nil +} + // MountAt creates and mounts a Filesystem configured by the given arguments. func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { rft := vfs.getFilesystemType(fsTypeName) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index 1fe766a44..bc7581698 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,6 +26,7 @@ import ( "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sync" ) @@ -160,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) loop: for { @@ -361,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr //go:linkname rand32 runtime.fastrand func rand32() uint32 - -// This is copy/pasted from runtime.noescape(), and is needed because arguments -// apparently escape from all functions defined by linkname. -// -//go:nosplit -func noescape(p unsafe.Pointer) unsafe.Pointer { - x := uintptr(p) - return unsafe.Pointer(x ^ 0) -} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 8a0b382f6..eb4ebb511 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() { rp.pit = next } else { // at end of path segment, continue with next one rp.curPart-- - rp.pit = rp.parts[rp.curPart-1] + rp.pit = rp.parts[rp.curPart] } } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 8f29031b2..bde81e1ef 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -126,17 +126,23 @@ func (vfs *VirtualFilesystem) Init() error { // Construct vfs.anonMount. anonfsDevMinor, err := vfs.GetAnonBlockDevMinor() if err != nil { - return err + // This shouldn't be possible since anonBlockDevMinorNext was + // initialized to 1 above (no device numbers have been allocated yet). + panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err)) } anonfs := anonFilesystem{ devMinor: anonfsDevMinor, } anonfs.vfsfs.Init(vfs, &anonfs) - vfs.anonMount = &Mount{ - vfs: vfs, - fs: &anonfs.vfsfs, - refs: 1, + defer anonfs.vfsfs.DecRef() + anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{}) + if err != nil { + // We should not be passing any MountOptions that would cause + // construction of this mount to fail. + panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err)) } + vfs.anonMount = anonMount + return nil } @@ -385,15 +391,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential // Only a regular file can be executed. stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE}) if err != nil { + fd.DecRef() return nil, err } - if stat.Mask&linux.STATX_TYPE != 0 { - // This shouldn't happen, but if type can't be retrieved, file can't - // be executed. - return nil, syserror.EACCES - } - if t := linux.FileMode(stat.Mode).FileType(); t != linux.ModeRegular { - ctx.Infof("%q is not a regular file: %v", pop.Path, t) + if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG { + fd.DecRef() return nil, syserror.EACCES } } diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 2269f6237..4b5a0fca6 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -29,6 +29,7 @@ var ( EACCES = error(syscall.EACCES) EAGAIN = error(syscall.EAGAIN) EBADF = error(syscall.EBADF) + EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) ECONNREFUSED = error(syscall.ECONNREFUSED) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..17e94c562 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + vv.views = append(vv.views, v) + vv.size += len(v) +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 70e6ce095..76e88e9b3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -115,6 +115,19 @@ const ( // for the secret key used to generate an opaque interface identifier as // outlined by RFC 7217. OpaqueIIDSecretKeyMinBytes = 16 + + // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field + // is located within a multicast IPv6 address, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeByteIdx = 1 + + // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, + // within the byte holding the field, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeMask = 0xF + + // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within + // a multicast IPv6 address that indicates the address has link-local scope, + // as per RFC 4291 section 2.7. + ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -340,6 +353,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 +// link-local multicast address. +func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { + return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope +} + // IsV6UniqueLocalAddress determines if the provided address is an IPv6 // unique-local address (within the prefix FC00::/7). func IsV6UniqueLocalAddress(addr tcpip.Address) bool { @@ -411,6 +430,9 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { } switch { + case IsV6LinkLocalMulticastAddress(addr): + return LinkLocalScope, nil + case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index c3ad503aa..426a873b1 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -27,11 +27,12 @@ import ( ) const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -256,6 +257,85 @@ func TestIsV6UniqueLocalAddress(t *testing.T) { } } +func TestIsV6LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: true, + }, + { + name: "Valid Link Local Multicast with flags", + addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + expected: true, + }, + { + name: "Link Local Unicast", + addr: linkLocalAddr, + expected: false, + }, + { + name: "IPv4 Multicast", + addr: "\xe0\x00\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV6LinkLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Unicast", + addr: linkLocalAddr, + expected: true, + }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: false, + }, + { + name: "Unique Local", + addr: uniqueLocalAddr1, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4 Link Local", + addr: "\xa9\xfe\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + func TestScopeForIPv6Address(t *testing.T) { tests := []struct { name string @@ -270,12 +350,18 @@ func TestScopeForIPv6Address(t *testing.T) { err: nil, }, { - name: "Link Local", + name: "Link Local Unicast", addr: linkLocalAddr, scope: header.LinkLocalScope, err: nil, }, { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + scope: header.LinkLocalScope, + err: nil, + }, + { name: "Global", addr: globalAddr, scope: header.GlobalScope, diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 3974c464e..b8b93e78e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["channel.go"], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 78d447acd..5944ba190 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -20,6 +20,7 @@ package channel import ( "context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,6 +34,118 @@ type PacketInfo struct { Route stack.Route } +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // mu protects fields below. + mu sync.RWMutex + // c is the outbound packet channel. Sending to c should hold mu. + c chan PacketInfo + numWrite int + numRead int + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + q.mu.Lock() + defer q.mu.Unlock() + select { + case p := <-q.c: + q.numRead++ + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + // We have to receive from channel without holding the lock, since it can + // block indefinitely. This will cause a window that numWrite - numRead + // produces a larger number, but won't go to negative. numWrite >= numRead + // still holds. + select { + case pkt := <-q.c: + q.mu.Lock() + defer q.mu.Unlock() + q.numRead++ + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + + // It's important to make sure nobody can see numWrite until we increment it, + // so numWrite >= numRead holds. + q.mu.Lock() + select { + case q.c <- p: + wrote = true + q.numWrite++ + default: + } + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + q.mu.RLock() + defer q.mu.RUnlock() + n := q.numWrite - q.numRead + if n < 0 { + panic("numWrite < numRead") + } + return n +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify +} + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -41,14 +154,16 @@ type Endpoint struct { linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities - // c is where outbound packets are queued. - c chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - c: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } @@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { // Close closes e. Further packet injections will panic. Reads continue to // succeed until all packets are read. func (e *Endpoint) Close() { - close(e.c) + e.q.Close() } -// Read does non-blocking read for one packet from the outbound packet queue. +// Read does non-blocking read one packet from the outbound packet queue. func (e *Endpoint) Read() (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - default: - return PacketInfo{}, false - } + return e.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns false. func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - case <-ctx.Done(): - return PacketInfo{}, false - } + return e.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.c: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + // InjectInbound injects an inbound packet. func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) @@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne Route: route, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } @@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac route.Release() payloadView := pkts[0].Data.ToView() n := 0 -packetLoop: for _, pkt := range pkts { off := pkt.DataOffset size := pkt.DataSize @@ -185,12 +289,10 @@ packetLoop: Route: route, } - select { - case e.c <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil @@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { GSO: nil, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e5096ea38..e0db6cf54 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -4,6 +4,22 @@ package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..6ff47a742 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,352 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release() { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef() + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + endpoint, err := attachOrCreateNIC(s, name, prefix) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + } + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := tcpip.PacketBuffer{ + Data: buffer.View(data).ToVectorisedView(), + } + if ethHdr != nil { + pkt.LinkHeader = buffer.View(ethHdr) + } + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader == nil { + hdr := &header.EthernetFields{ + SrcAddr: info.Route.LocalLinkAddress, + DstAddr: info.Route.RemoteLinkAddress, + Type: info.Proto, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = d.endpoint.LinkAddress() + } + + eth := make(header.Ethernet, header.EthernetMinimumSize) + eth.Encode(hdr) + vv.AppendView(buffer.View(eth)) + } else { + vv.AppendView(info.Pkt.LinkHeader) + } + } + + // Append upper headers. + vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef() { + e.DecRefWithDestructor(func() { + e.stack.RemoveNIC(e.nicID) + }) +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 045409bda..f651871ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// cleanupState cleans up ndp's state. // -// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a -// host to a router. This function will invalidate all discovered on-link -// prefixes, discovered routers, and auto-generated addresses as routers do not -// normally process Router Advertisements to discover default routers and -// on-link prefixes, and auto-generate addresses via SLAAC. +// If hostOnly is true, then only host-specific state will be cleaned up. +// +// cleanupState MUST be called with hostOnly set to true when ndp's NIC is +// transitioning from a host to a router. This function will invalidate all +// discovered on-link prefixes, discovered routers, and auto-generated +// addresses. +// +// If hostOnly is true, then the link-local auto-generated address will not be +// invalidated as routers are also expected to generate a link-local address. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupHostOnlyState() { +func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them. - if linkLocalSubnet.Contains(addr) { + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if hostOnly && linkLocalSubnet.Contains(addr) { linkLocalAddrs++ continue } @@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() { } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1f6f77439..6e9306d09 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// channelLinkWithHeaderLength is a channel.Endpoint with a configurable +// header length. +type channelLinkWithHeaderLength struct { + *channel.Endpoint + headerLength uint16 +} + +func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { + return l.headerLength +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { @@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) { // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +// This tests also validates the NDP NS packet that is transmitted. func TestDADResolve(t *testing.T) { const nicID = 1 tests := []struct { name string + linkHeaderLen uint16 dupAddrDetectTransmits uint8 retransTimer time.Duration expectedRetransmitTimer time.Duration }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + { + name: "1:1s:1s", + dupAddrDetectTransmits: 1, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "2:1s:1s", + linkHeaderLen: 1, + dupAddrDetectTransmits: 2, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "1:2s:2s", + linkHeaderLen: 2, + dupAddrDetectTransmits: 1, + retransTimer: 2 * time.Second, + expectedRetransmitTimer: 2 * time.Second, + }, // 0s is an invalid RetransmitTimer timer and will be fixed to // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, + { + name: "1:0s:1s", + linkHeaderLen: 3, + dupAddrDetectTransmits: 1, + retransTimer: 0, + expectedRetransmitTimer: time.Second, + }, } for _, test := range tests { @@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) { checker.NDPNSTargetAddress(addr1), checker.NDPNSOptions(nil), )) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } }) } @@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) { } } -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. func TestDADStop(t *testing.T) { const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } + tests := []struct { + name string + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests to make sure that DAD stops when an address is removed. + { + name: "Remove address", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) + } + }, + }, - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + // Tests to make sure that DAD stops when the NIC is disabled. + { + name: "Disable NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + }, + }, } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + test.stopFn(t, s) + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Errorf("got NeighborSolicit = %d, want <= 1", got) + } + }) } } @@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } -// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and non-linklocal auto-generated addresses are invalidated when -// a NIC becomes a router. -func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { +// TestCleanupNDPState tests that all discovered routers and prefixes, and +// auto-generated addresses are invalidated when a NIC becomes a router. +func TestCleanupNDPState(t *testing.T) { t.Parallel() const ( - lifetimeSeconds = 5 - maxEvents = 4 - nicID1 = 1 - nicID2 = 2 + lifetimeSeconds = 5 + maxRouterAndPrefixEvents = 4 + nicID1 = 1 + nicID2 = 2 ) prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { PrefixLen: 64, } - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, + tests := []struct { + name string + cleanupFn func(t *testing.T, s *stack.Stack) + keepAutoGenLinkLocal bool + maxAutoGenAddrEvents int + }{ + // A NIC should still keep its auto-generated link-local address when + // becoming a router. + { + name: "Forwarding Enable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, + keepAutoGenLinkLocal: true, + maxAutoGenAddrEvents: 4, }, - NDPDisp: &ndpDisp, - }) - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } + // A NIC should cleanup all NDP state when it is disabled. + { + name: "NIC Disable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - return false, ndpRouterEvent{} + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + }, } - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) - return false, ndpPrefixEvent{} - } + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } + return false, ndpRouterEvent{} + } - return false, ndpAutoGenAddrEvent{} - } + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() + return false, ndpPrefixEvent{} + } - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. + return false, ndpAutoGenAddrEvent{} + } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + expectAutoGenAddrEvent() - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } - s.SetForwarding(true) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } - // Collect invalidation events after becoming a router - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } + test.cleanupFn(t, s) - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + // Collect invalidation events after having NDP state cleaned up. + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < test.maxAutoGenAddrEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } - // Should not get any more events (invalidation timers should have been - // cancelled when we transitioned into a router). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + + if !test.keepAutoGenLinkLocal { + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 + } + + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + } + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when the NDP state was cleaned up). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } + }) } } @@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { func TestRouterSolicitation(t *testing.T) { t.Parallel() + const nicID = 1 + tests := []struct { name string + linkHeaderLen uint16 maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with delay", + linkHeaderLen: 1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Single RS without delay", + linkHeaderLen: 2, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS without delay and invalid zero interval", + linkHeaderLen: 3, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS(), ) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } waitForNothing := func(timeout time.Duration) { t.Helper() @@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Make sure each RS got sent at the right @@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) { }) } -// TestStopStartSolicitingRouters tests that when forwarding is enabled or -// disabled, router solicitations are stopped or started, respecitively. func TestStopStartSolicitingRouters(t *testing.T) { t.Parallel() + const nicID = 1 const interval = 500 * time.Millisecond const delay = time.Second const maxRtrSolicitations = 3 - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, + tests := []struct { + name string + startFn func(t *testing.T, s *stack.Stack) + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests that when forwarding is enabled or disabled, router solicitations + // are stopped or started, respectively. + { + name: "Forwarding enabled and disabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(false) + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, }, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Enable forwarding which should stop router solicitations. - s.SetForwarding(true) - ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("Should not have sent more than one RS message") - } - } + // Tests that when a NIC is enabled or disabled, router solicitations + // are started or stopped, respectively. + { + name: "NIC disabled and enabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Enabling forwarding again should do nothing. - s.SetForwarding(true) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") - } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Disable forwarding which should start router solicitations. - s.SetForwarding(false) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + }, + }, } - // Disabling forwarding again should do nothing. - s.SetForwarding(false) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + return + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Stop soliciting routers. + test.stopFn(t, s) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + // A single RS may have been sent before forwarding was enabled. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok = e.ReadContext(ctx); ok { + t.Fatal("should not have sent more than one RS message") + } + } + + // Stopping router solicitations after it has already been stopped should + // do nothing. + test.stopFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") + } + + // Start soliciting routers. + test.startFn(t, s) + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + } + + // Starting router solicitations after it has already completed should do + // nothing. + test.startFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after finishing router solicitations") + } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca3a7a07e..46d3a6646 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,6 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +var ipv4BroadcastAddr = tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, + }, +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -132,11 +140,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + nic.linkEP.Attach(nic) + return nic } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). +// enabled returns true if n is enabled. +func (n *NIC) enabled() bool { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + return enabled +} + +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if !enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if !n.mu.enabled { + return nil + } + + // TODO(b/147015577): Should Routes that are currently bound to n be + // invalidated? Currently, Routes will continue to work when a NIC is enabled + // again, and applications may not know that the underlying NIC was ever + // disabled. + + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + n.mu.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupState(false /* hostOnly */) + + // Stop DAD for all the unicast IPv6 endpoints that are in the + // permanentTentative state. + for _, r := range n.mu.endpoints { + if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + n.mu.ndp.stopDuplicateAddressDetection(addr) + } + } + + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The address may have already been removed. + if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + n.mu.enabled = false + return nil +} + +// enable enables n. +// +// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast +// address (ff02::1), start DAD for permanent addresses, and start soliciting +// routers if the stack is not operating as a router. If the stack is also +// configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { n.mu.RLock() enabled := n.mu.enabled @@ -154,14 +228,9 @@ func (n *NIC) enable() *tcpip.Error { n.mu.enabled = true - n.attachLinkEndpoint() - // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -183,6 +252,14 @@ func (n *NIC) enable() *tcpip.Error { return nil } + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // @@ -200,10 +277,6 @@ func (n *NIC) enable() *tcpip.Error { } } - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { // The valid and preferred lifetime is infinite for the auto-generated @@ -225,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error { return nil } +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + // Remove permanent and permanentTentative addresses, so no packet goes out. + var errs []*tcpip.Error + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return errs[0] + } + + return nil +} + // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered @@ -234,7 +334,7 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.cleanupState(true /* hostOnly */) n.mu.ndp.stopSolicitingRouters() } @@ -249,12 +349,6 @@ func (n *NIC) becomeIPv6Host() { n.mu.ndp.startSolicitingRouters() } -// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it -// to start delivering packets. -func (n *NIC) attachLinkEndpoint() { - n.linkEP.Attach(n) -} - // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() @@ -712,6 +806,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { case permanentExpired, temporary: continue } + addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -1009,6 +1104,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } +// isInGroup returns true if n has joined the multicast group addr. +func (n *NIC) isInGroup(addr tcpip.Address) bool { + n.mu.RLock() + joins := n.mu.mcastJoins[NetworkEndpointID{addr}] + n.mu.RUnlock() + + return joins != 0 +} + func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1225,6 +1329,11 @@ func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return @@ -1411,7 +1520,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { // // r's NIC must be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.getKind() != permanentExpired || r.nic.mu.spoofing + return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d83adf0ec..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6eac16e16..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -881,6 +881,8 @@ type NICOptions struct { // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and // NICOptions. See the documentation on type NICOptions for details on how // NICs can be configured. +// +// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) - s.nics[id] = n if !opts.Disabled { return n.enable() @@ -910,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls -// `LinkEndpoint.Attach` to start delivering packets to it. +// LinkEndpoint.Attach to bind ep with a NetworkDispatcher. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false +} + // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } return nic.enable() } +// DisableNIC disables the given NIC. +func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.disable() +} + // CheckNIC checks if a NIC is usable. func (s *Stack) CheckNIC(id tcpip.NICID) bool { s.mu.RLock() + defer s.mu.RUnlock() + nic, ok := s.nics[id] - s.mu.RUnlock() - if ok { - return nic.linkEP.IsAttached() + if !ok { + return false } - return false + + return nic.enabled() +} + +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + if r.NIC != id { + // Keep this route. + if i > n { + s.routeTable[n] = r + } + n++ + } + } + s.routeTable = s.routeTable[:n] + + return nic.remove() } // NICAddressRanges returns a map of NICIDs to their associated subnets. @@ -989,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.linkEP.IsAttached(), + Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), Loopback: nic.isLoopback(), } @@ -1151,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok { + if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } @@ -1161,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok { + if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route @@ -1391,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1409,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() @@ -1614,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC return tcpip.ErrUnknownNICID } +// IsInGroup returns true if the NIC with ID nicID has joined the multicast +// group multicastAddr. +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.isInGroup(multicastAddr), nil + } + return false, tcpip.ErrUnknownNICID +} + // IPTables returns the stack's iptables. func (s *Stack) IPTables() iptables.IPTables { s.tablesMu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7ba604442..e15db40fb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = true +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatalf("link endpoint not attached to a network disatcher") + } + }) + } +} + +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRoutesWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Disabling NIC1 should result in no routes to odd addresses. Routes to even + // addresses should continue to be available as NIC2 is still enabled. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + nic1Dst := tcpip.Address("\x05") + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + nic2Dst := tcpip.Address("\x06") + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Disabling NIC2 should result in no routes to even addresses. No route + // should be available to any address as routes to odd addresses were made + // unavailable by disabling NIC1 above. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + // Enabling NIC1 should make routes to odd addresses available again. Routes + // to even addresses should continue to be unavailable as NIC2 is still + // disabled. + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) +} + +func TestRouteWritePacketWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + nic1Dst := tcpip.Address("\x05") + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + nic2Dst := tcpip.Address("\x06") + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC1 should fail. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC2 should fail. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + // Writes with Routes that use the re-enabled NIC1 should succeed. + // TODO(b/147015577): Should we instead completely invalidate all Routes that + // were bound to a disabled NIC at some point? + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -2173,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName} + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - var expectedMainAddr tcpip.AddressWithPrefix + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { expectedMainAddr = tcpip.AddressWithPrefix{ Address: test.expectedAddr, @@ -2460,13 +2790,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 ) // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. @@ -2540,6 +2871,18 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { expectedLocalAddr: linkLocalAddr1, }, { + name: "Link Local most preferred for link local multicast (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { name: "Unique Local most preferred (last address)", nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, connectAddr: uniqueLocalAddr2, @@ -2609,6 +2952,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } } +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 1 { + t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) + } + want := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + if allNICAddrs[0] != want { + t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } +} + +func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the IPv6 all-nodes multicast group yet because the NIC has + // not been enabled yet. + isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } +} + // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ce5527391..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 48764b978..2f98a996f 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -25,6 +25,8 @@ import ( ) // StdClock implements Clock with the time package. +// +// +stateify savable type StdClock struct{} var _ Clock = (*StdClock)(nil) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cc118c993..5b2b16afa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { ), ) - // Wait for the TIME-WAIT state to transition to CLOSED. - time.Sleep(1 * time.Second) + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index eff7f3600..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index ff8b9e91a..6c9ada9c7 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -25,7 +25,6 @@ go_library( "bytes_io_unsafe.go", "usermem.go", "usermem_arm64.go", - "usermem_unsafe.go", "usermem_x86.go", ], visibility = ["//:sandbox"], @@ -33,6 +32,7 @@ go_library( "//pkg/atomicbitops", "//pkg/binary", "//pkg/context", + "//pkg/gohacks", "//pkg/log", "//pkg/safemem", "//pkg/syserror", diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 71fd4e155..d2f4403b0 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt } end, ok := addr.AddLength(uint64(readlen)) if !ok { - return stringFromImmutableBytes(buf[:done]), syserror.EFAULT + return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT } // Shorten the read to avoid crossing page boundaries, since faulting // in a page unnecessarily is expensive. This also ensures that partial @@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt // Look for the terminating zero byte, which may have occurred before // hitting err. if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { - return stringFromImmutableBytes(buf[:done+i]), nil + return gohacks.StringFromImmutableBytes(buf[:done+i]), nil } done += n if err != nil { - return stringFromImmutableBytes(buf[:done]), err + return gohacks.StringFromImmutableBytes(buf[:done]), err } addr = end } - return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG + return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG } // CopyOutVec copies bytes from src to the memory mapped at ars in uio. The |