diff options
Diffstat (limited to 'pkg')
157 files changed, 3741 insertions, 1406 deletions
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go index 31e56ffa6..ef6d1093e 100644 --- a/pkg/abi/linux/ip.go +++ b/pkg/abi/linux/ip.go @@ -92,6 +92,16 @@ const ( IP_UNICAST_IF = 50 ) +// IP_MTU_DISCOVER values from uapi/linux/in.h +const ( + IP_PMTUDISC_DONT = 0 + IP_PMTUDISC_WANT = 1 + IP_PMTUDISC_DO = 2 + IP_PMTUDISC_PROBE = 3 + IP_PMTUDISC_INTERFACE = 4 + IP_PMTUDISC_OMIT = 5 +) + // Socket options from uapi/linux/in6.h const ( IPV6_ADDRFORM = 1 diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index 0e5b86344..b789e56e9 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -28,12 +28,11 @@ func (b *buffer) ReadBlock() safemem.Block { return safemem.BlockFromSafeSlice(b.ReadSlice()) } -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -// -// This will advance the write index. -func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - need := int(srcs.NumBytes()) - if need == 0 { +// WriteFromSafememReader writes up to count bytes from r to v and advances the +// write index by the number of bytes written. It calls r.ReadToBlocks() at +// most once. +func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -50,32 +49,33 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { } // Does the last block have sufficient capacity alone? - if l := firstBuf.WriteSize(); l >= need { - dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) + if l := uint64(firstBuf.WriteSize()); l >= count { + dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count)) } else { // Append blocks until sufficient. - need -= l + count -= l blocks = append(blocks, firstBuf.WriteBlock()) - for need > 0 { + for count > 0 { emptyBuf := bufferPool.Get().(*buffer) v.data.PushBack(emptyBuf) - need -= emptyBuf.WriteSize() - blocks = append(blocks, emptyBuf.WriteBlock()) + block := emptyBuf.WriteBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } dst = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dst, srcs) + // Perform I/O. + n, err := r.ReadToBlocks(dst) v.size += int64(n) // Update all indices. - for left := int(n); left > 0; firstBuf = firstBuf.Next() { - if l := firstBuf.WriteSize(); left >= l { + for left := n; left > 0; firstBuf = firstBuf.Next() { + if l := firstBuf.WriteSize(); left >= uint64(l) { firstBuf.WriteMove(l) // Whole block. - left -= l + left -= uint64(l) } else { - firstBuf.WriteMove(left) // Partial block. + firstBuf.WriteMove(int(left)) // Partial block. left = 0 } } @@ -83,14 +83,16 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -// -// This will not advance the read index; the caller should follow -// this call with a call to TrimFront in order to remove the read -// data from the buffer. This is done to support pipe sematics. -func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - need := int(dsts.NumBytes()) - if need == 0 { +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the +// write index by the number of bytes written. +func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes()) +} + +// ReadToSafememWriter reads up to count bytes from v to w. It does not advance +// the read index. It calls w.WriteFromBlocks() at most once. +func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -105,25 +107,27 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { } // Is all the data in a single block? - if l := firstBuf.ReadSize(); l >= need { - src = safemem.BlockSeqOf(firstBuf.ReadBlock()) + if l := uint64(firstBuf.ReadSize()); l >= count { + src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count)) } else { // Build a list of all the buffers. - need -= l + count -= l blocks = append(blocks, firstBuf.ReadBlock()) - for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { - need -= buf.ReadSize() - blocks = append(blocks, buf.ReadBlock()) + for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() { + block := buf.ReadBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } src = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dsts, src) - - // See above: we would normally advance the read index here, but we - // don't do that in order to support pipe semantics. We rely on a - // separate call to TrimFront() in this case. + // Perform I/O. As documented, we don't advance the read index. + return w.WriteFromBlocks(src) +} - return n, err +// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the +// read index by the number of bytes read, such that it's only safe to call if +// the caller guarantees that ReadToBlocks will only be called once. +func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes()) } diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD new file mode 100644 index 000000000..5c34b9872 --- /dev/null +++ b/pkg/cleanup/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "cleanup", + srcs = ["cleanup.go"], + visibility = ["//:sandbox"], + deps = [ + ], +) + +go_test( + name = "cleanup_test", + srcs = ["cleanup_test.go"], + library = ":cleanup", +) diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go new file mode 100644 index 000000000..14a05f076 --- /dev/null +++ b/pkg/cleanup/cleanup.go @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cleanup provides utilities to clean "stuff" on defers. +package cleanup + +// Cleanup allows defers to be aborted when cleanup needs to happen +// conditionally. Usage: +// cu := cleanup.Make(func() { f.Close() }) +// defer cu.Clean() // failure before release is called will close the file. +// ... +// cu.Add(func() { f2.Close() }) // Adds another cleanup function +// ... +// cu.Release() // on success, aborts closing the file. +// return f +type Cleanup struct { + cleaners []func() +} + +// Make creates a new Cleanup object. +func Make(f func()) Cleanup { + return Cleanup{cleaners: []func(){f}} +} + +// Add adds a new function to be called on Clean(). +func (c *Cleanup) Add(f func()) { + c.cleaners = append(c.cleaners, f) +} + +// Clean calls all cleanup functions in reverse order. +func (c *Cleanup) Clean() { + clean(c.cleaners) + c.cleaners = nil +} + +// Release releases the cleanup from its duties, i.e. cleanup functions are not +// called after this point. Returns a function that calls all registered +// functions in case the caller has use for them. +func (c *Cleanup) Release() func() { + old := c.cleaners + c.cleaners = nil + return func() { clean(old) } +} + +func clean(cleaners []func()) { + for i := len(cleaners) - 1; i >= 0; i-- { + cleaners[i]() + } +} diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go new file mode 100644 index 000000000..ab3d9ed95 --- /dev/null +++ b/pkg/cleanup/cleanup_test.go @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cleanup + +import "testing" + +func testCleanupHelper(clean, cleanAdd *bool, release bool) func() { + cu := Make(func() { + *clean = true + }) + cu.Add(func() { + *cleanAdd = true + }) + defer cu.Clean() + if release { + return cu.Release() + } + return nil +} + +func TestCleanup(t *testing.T) { + clean := false + cleanAdd := false + testCleanupHelper(&clean, &cleanAdd, false) + if !clean { + t.Fatalf("cleanup function was not called.") + } + if !cleanAdd { + t.Fatalf("added cleanup function was not called.") + } +} + +func TestRelease(t *testing.T) { + clean := false + cleanAdd := false + cleaner := testCleanupHelper(&clean, &cleanAdd, true) + + // Check that clean was not called after release. + if clean { + t.Fatalf("cleanup function was called.") + } + if cleanAdd { + t.Fatalf("added cleanup function was called.") + } + + // Call the cleaner function and check that both cleanup functions are called. + cleaner() + if !clean { + t.Fatalf("cleanup function was not called.") + } + if !cleanAdd { + t.Fatalf("added cleanup function was not called.") + } +} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 38cea9be3..7c622e5d7 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.15 +// +build !go1.16 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 4f4b70fef..48ebb5fd1 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.15 +// +build !go1.16 #include "textflag.h" diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index e74275d2d..0c9a62f0d 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/fdimport", "//pkg/sentry/fs", "//pkg/sentry/fs/host", + "//pkg/sentry/fs/user", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/host", "//pkg/sentry/kernel", diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 2ed17ee09..8767430b7 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -18,7 +18,6 @@ import ( "bytes" "encoding/json" "fmt" - "path" "sort" "strings" "text/tabwriter" @@ -28,10 +27,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fdimport" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" + "gvisor.dev/gvisor/pkg/sentry/fs/user" "gvisor.dev/gvisor/pkg/sentry/fsbridge" hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -190,17 +189,12 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI // transferred to the new process. initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2() } - - paths := fs.GetPath(initArgs.Envv) - vfsObj := proc.Kernel.VFS() - file, err := ResolveExecutablePath(ctx, vfsObj, initArgs.WorkingDirectory, initArgs.Argv[0], paths) + file, err := getExecutableFD(ctx, creds, proc.Kernel.VFS(), initArgs.MountNamespaceVFS2, initArgs.Envv, initArgs.WorkingDirectory, initArgs.Argv[0]) if err != nil { - return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err) + return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in environment %v: %v", initArgs.Argv[0], initArgs.Envv, err) } initArgs.File = fsbridge.NewVFSFile(file) } else { - // Get the full path to the filename from the PATH env variable. - paths := fs.GetPath(initArgs.Envv) if initArgs.MountNamespace == nil { // Set initArgs so that 'ctx' returns the namespace. initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace() @@ -209,9 +203,9 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI // be donated to the new process in CreateProcess. initArgs.MountNamespace.IncRef() } - f, err := initArgs.MountNamespace.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths) + f, err := user.ResolveExecutablePath(ctx, creds, initArgs.MountNamespace, initArgs.Envv, initArgs.WorkingDirectory, initArgs.Argv[0]) if err != nil { - return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err) + return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], initArgs.Envv, err) } initArgs.Filename = f } @@ -429,53 +423,17 @@ func ttyName(tty *kernel.TTY) string { return fmt.Sprintf("pts/%d", tty.Index) } -// ResolveExecutablePath resolves the given executable name given a set of -// paths that might contain it. -func ResolveExecutablePath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, wd, name string, paths []string) (*vfs.FileDescription, error) { - root := vfs.RootFromContext(ctx) - defer root.DecRef() - creds := auth.CredentialsFromContext(ctx) - - // Absolute paths can be used directly. - if path.IsAbs(name) { - return openExecutable(ctx, vfsObj, creds, root, name) - } - - // Paths with '/' in them should be joined to the working directory, or - // to the root if working directory is not set. - if strings.IndexByte(name, '/') > 0 { - if len(wd) == 0 { - wd = "/" - } - if !path.IsAbs(wd) { - return nil, fmt.Errorf("working directory %q must be absolute", wd) - } - return openExecutable(ctx, vfsObj, creds, root, path.Join(wd, name)) +// getExecutableFD resolves the given executable name and returns a +// vfs.FileDescription for the executable file. +func getExecutableFD(ctx context.Context, creds *auth.Credentials, vfsObj *vfs.VirtualFilesystem, mns *vfs.MountNamespace, envv []string, wd, name string) (*vfs.FileDescription, error) { + path, err := user.ResolveExecutablePathVFS2(ctx, creds, mns, envv, wd, name) + if err != nil { + return nil, err } - // Otherwise, we must lookup the name in the paths, starting from the - // calling context's root directory. - for _, p := range paths { - if !path.IsAbs(p) { - // Relative paths aren't safe, no one should be using them. - log.Warningf("Skipping relative path %q in $PATH", p) - continue - } - - binPath := path.Join(p, name) - f, err := openExecutable(ctx, vfsObj, creds, root, binPath) - if err != nil { - return nil, err - } - if f == nil { - continue // Not found/no access. - } - return f, nil - } - return nil, syserror.ENOENT -} + root := vfs.RootFromContext(ctx) + defer root.DecRef() -func openExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) (*vfs.FileDescription, error) { pop := vfs.PathOperation{ Root: root, Start: root, // binPath is absolute, Start can be anything. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 846252c89..2a278fbe3 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -310,7 +310,6 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error if !f.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append) // Handle append mode. if f.Flags().Append { @@ -355,7 +354,6 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64 // offset." unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append) defer unlockAppendMu() - if f.Flags().Append { if err := f.offsetForAppend(ctx, &offset); err != nil { return 0, err @@ -374,9 +372,10 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64 return f.FileOperations.Write(ctx, f, src, offset) } -// offsetForAppend sets the given offset to the end of the file. +// offsetForAppend atomically sets the given offset to the end of the file. // -// Precondition: the file.Dirent.Inode.appendMu mutex should be held for writing. +// Precondition: the file.Dirent.Inode.appendMu mutex should be held for +// writing. func (f *File) offsetForAppend(ctx context.Context, offset *int64) error { uattr, err := f.Dirent.Inode.UnstableAttr(ctx) if err != nil { @@ -386,7 +385,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error { } // Update the offset. - *offset = uattr.Size + atomic.StoreInt64(offset, uattr.Size) return nil } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index bdba6efe5..d2dbff268 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -42,9 +42,10 @@ // Dirent.dirMu // Dirent.mu // DirentCache.mu -// Locks in InodeOperations implementations or overlayEntry // Inode.Watches.mu (see `Inotify` for other lock ordering) // MountSource.mu +// Inode.appendMu +// Locks in InodeOperations implementations or overlayEntry // // If multiple Dirent or MountSource locks must be taken, locks in the parent must be // taken before locks in their children. diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md index 635cc009b..2ca84dd74 100644 --- a/pkg/sentry/fs/g3doc/fuse.md +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -76,7 +76,8 @@ ops can be implemented in parallel. #### Minimal client that can mount a trivial FUSE filesystem. -- Implement `/dev/fuse`. +- Implement `/dev/fuse` - a character device used to establish an FD for + communication between the sentry and the server daemon. - Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. @@ -99,7 +100,7 @@ ops can be implemented in parallel. ## FUSE Protocol The FUSE protocol is a request-response protocol. All requests are initiated by -the client. The wire-format for the protocol is raw c structs serialized to +the client. The wire-format for the protocol is raw C structs serialized to memory. All FUSE requests begin with the following request header: @@ -255,6 +256,8 @@ I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`. # References -- `fuse(4)` manpage. -- Linux kernel FUSE documentation: - https://www.kernel.org/doc/html/latest/filesystems/fuse.html +- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html) +- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html) +- [The reference implementation of the Linux FUSE (Filesystem in Userspace) + interface](https://github.com/libfuse/libfuse) +- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index b414ddaee..3f2bd0e87 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -17,13 +17,9 @@ package fs import ( "fmt" "math" - "path" - "strings" "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" @@ -625,71 +621,3 @@ func (mns *MountNamespace) SyncAll(ctx context.Context) { defer mns.mu.Unlock() mns.root.SyncAll(ctx) } - -// ResolveExecutablePath resolves the given executable name given a set of -// paths that might contain it. -func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name string, paths []string) (string, error) { - // Absolute paths can be used directly. - if path.IsAbs(name) { - return name, nil - } - - // Paths with '/' in them should be joined to the working directory, or - // to the root if working directory is not set. - if strings.IndexByte(name, '/') > 0 { - if wd == "" { - wd = "/" - } - if !path.IsAbs(wd) { - return "", fmt.Errorf("working directory %q must be absolute", wd) - } - return path.Join(wd, name), nil - } - - // Otherwise, We must lookup the name in the paths, starting from the - // calling context's root directory. - root := RootFromContext(ctx) - if root == nil { - // Caller has no root. Don't bother traversing anything. - return "", syserror.ENOENT - } - defer root.DecRef() - for _, p := range paths { - binPath := path.Join(p, name) - traversals := uint(linux.MaxSymlinkTraversals) - d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) - if err == syserror.ENOENT || err == syserror.EACCES { - // Didn't find it here. - continue - } - if err != nil { - return "", err - } - defer d.DecRef() - - // Check that it is a regular file. - if !IsRegular(d.Inode.StableAttr) { - continue - } - - // Check whether we can read and execute the found file. - if err := d.Inode.CheckPermission(ctx, PermMask{Read: true, Execute: true}); err != nil { - log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err) - continue - } - return path.Join("/", p, name), nil - } - return "", syserror.ENOENT -} - -// GetPath returns the PATH as a slice of strings given the environment -// variables. -func GetPath(env []string) []string { - const prefix = "PATH=" - for _, e := range env { - if strings.HasPrefix(e, prefix) { - return strings.Split(strings.TrimPrefix(e, prefix), ":") - } - } - return nil -} diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD index f37f979f1..bd5dac373 100644 --- a/pkg/sentry/fs/user/BUILD +++ b/pkg/sentry/fs/user/BUILD @@ -4,15 +4,20 @@ package(licenses = ["notice"]) go_library( name = "user", - srcs = ["user.go"], + srcs = [ + "path.go", + "user.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/log", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go new file mode 100644 index 000000000..fbd4547a7 --- /dev/null +++ b/pkg/sentry/fs/user/path.go @@ -0,0 +1,169 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package user + +import ( + "fmt" + "path" + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// ResolveExecutablePath resolves the given executable name given the working +// dir and environment. +func ResolveExecutablePath(ctx context.Context, creds *auth.Credentials, mns *fs.MountNamespace, envv []string, wd, name string) (string, error) { + // Absolute paths can be used directly. + if path.IsAbs(name) { + return name, nil + } + + // Paths with '/' in them should be joined to the working directory, or + // to the root if working directory is not set. + if strings.IndexByte(name, '/') > 0 { + if wd == "" { + wd = "/" + } + if !path.IsAbs(wd) { + return "", fmt.Errorf("working directory %q must be absolute", wd) + } + return path.Join(wd, name), nil + } + + // Otherwise, We must lookup the name in the paths, starting from the + // calling context's root directory. + paths := getPath(envv) + + root := fs.RootFromContext(ctx) + if root == nil { + // Caller has no root. Don't bother traversing anything. + return "", syserror.ENOENT + } + defer root.DecRef() + for _, p := range paths { + if !path.IsAbs(p) { + // Relative paths aren't safe, no one should be using them. + log.Warningf("Skipping relative path %q in $PATH", p) + continue + } + + binPath := path.Join(p, name) + traversals := uint(linux.MaxSymlinkTraversals) + d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) + if err == syserror.ENOENT || err == syserror.EACCES { + // Didn't find it here. + continue + } + if err != nil { + return "", err + } + defer d.DecRef() + + // Check that it is a regular file. + if !fs.IsRegular(d.Inode.StableAttr) { + continue + } + + // Check whether we can read and execute the found file. + if err := d.Inode.CheckPermission(ctx, fs.PermMask{Read: true, Execute: true}); err != nil { + log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err) + continue + } + return path.Join("/", p, name), nil + } + + // Couldn't find it. + return "", syserror.ENOENT +} + +// ResolveExecutablePathVFS2 resolves the given executable name given the +// working dir and environment. +func ResolveExecutablePathVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, envv []string, wd, name string) (string, error) { + // Absolute paths can be used directly. + if path.IsAbs(name) { + return name, nil + } + + // Paths with '/' in them should be joined to the working directory, or + // to the root if working directory is not set. + if strings.IndexByte(name, '/') > 0 { + if wd == "" { + wd = "/" + } + if !path.IsAbs(wd) { + return "", fmt.Errorf("working directory %q must be absolute", wd) + } + return path.Join(wd, name), nil + } + + // Otherwise, We must lookup the name in the paths, starting from the + // calling context's root directory. + paths := getPath(envv) + + root := mns.Root() + defer root.DecRef() + for _, p := range paths { + if !path.IsAbs(p) { + // Relative paths aren't safe, no one should be using them. + log.Warningf("Skipping relative path %q in $PATH", p) + continue + } + + binPath := path.Join(p, name) + pop := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(binPath), + FollowFinalSymlink: true, + } + opts := &vfs.OpenOptions{ + FileExec: true, + Flags: linux.O_RDONLY, + } + dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts) + if err == syserror.ENOENT || err == syserror.EACCES { + // Didn't find it here. + continue + } + if err != nil { + return "", err + } + dentry.DecRef() + + return binPath, nil + } + + // Couldn't find it. + return "", syserror.ENOENT +} + +// getPath returns the PATH as a slice of strings given the environment +// variables. +func getPath(env []string) []string { + const prefix = "PATH=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.Split(strings.TrimPrefix(e, prefix), ":") + } + } + return nil +} diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go index fe7f67c00..f4d525523 100644 --- a/pkg/sentry/fs/user/user.go +++ b/pkg/sentry/fs/user/user.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package user contains methods for resolving filesystem paths based on the +// user and their environment. package user import ( diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index bfbd7c3d4..6bd1a9fc6 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -60,3 +60,15 @@ func (d *dentry) DecRef() { // inode.decRef(). d.inode.decRef() } + +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) Watches() *vfs.Watches { + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 67e916525..f5f35a3bc 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -35,6 +35,7 @@ go_library( "fstree.go", "gofer.go", "handle.go", + "host_named_pipe.go", "p9file.go", "regular_file.go", "socket.go", @@ -47,6 +48,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fd", + "//pkg/fdnotifier", "//pkg/fspath", "//pkg/log", "//pkg/p9", @@ -71,6 +73,7 @@ go_library( "//pkg/unet", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 7f2181216..36e0e1856 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -760,7 +760,7 @@ afterTrailingSymlink: parent.dirMu.Unlock() return nil, syserror.EPERM } - fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts) + fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts, &ds) parent.dirMu.Unlock() return fd, err } @@ -873,19 +873,37 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL } - h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0) + // We assume that the server silently inserts O_NONBLOCK in the open flags + // for all named pipes (because all existing gofers do this). + // + // NOTE(b/133875563): This makes named pipe opens racy, because the + // mechanisms for translating nonblocking to blocking opens can only detect + // the instantaneous presence of a peer holding the other end of the pipe + // open, not whether the pipe was *previously* opened by a peer that has + // since closed its end. + isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0 +retry: + h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) if err != nil { + if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO { + // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails + // with ENXIO if opening the same named pipe with O_WRONLY would + // block because there are no readers of the pipe. + if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { + return nil, err + } + goto retry + } return nil, err } - seekable := d.fileType() == linux.S_IFREG - fd := &specialFileFD{ - handle: h, - seekable: seekable, + if isBlockingOpenOfNamedPipe && ats == vfs.MayRead && h.fd >= 0 { + if err := blockUntilNonblockingPipeHasWriter(ctx, h.fd); err != nil { + h.close(ctx) + return nil, err + } } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - DenyPRead: !seekable, - DenyPWrite: !seekable, - }); err != nil { + fd, err := newSpecialFileFD(h, mnt, d, opts.Flags) + if err != nil { h.close(ctx) return nil, err } @@ -894,7 +912,7 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts // Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. // !d.isSynthetic(). -func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } @@ -947,6 +965,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } + *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -959,10 +978,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags) child.handleMu.Unlock() } - // Take a reference on the new dentry to be held by the new file - // description. (This reference also means that the new dentry is not - // eligible for caching yet, so we don't need to append to a dentry slice.) - child.refs = 1 // Insert the dentry into the tree. d.cacheNewChildLocked(child, name) if d.cachedMetadataAuthoritative() { @@ -981,22 +996,16 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } else { - seekable := child.fileType() == linux.S_IFREG - fd := &specialFileFD{ - handle: handle{ - file: openFile, - fd: -1, - }, - seekable: seekable, + h := handle{ + file: openFile, + fd: -1, } if fdobj != nil { - fd.handle.fd = int32(fdobj.Release()) + h.fd = int32(fdobj.Release()) } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ - DenyPRead: !seekable, - DenyPWrite: !seekable, - }); err != nil { - fd.handle.close(ctx) + fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) + if err != nil { + h.close(ctx) return nil, err } childVFSFD = &fd.vfsfd diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 6295f6b54..3f3bd56f0 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -84,12 +84,6 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 - // uid and gid are the effective KUID and KGID of the filesystem's creator, - // and are used as the owner and group for files that don't specify one. - // uid and gid are immutable. - uid auth.KUID - gid auth.KGID - // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -122,6 +116,8 @@ type filesystemOptions struct { fd int aname string interop InteropMode // derived from the "cache" mount option + dfltuid auth.KUID + dfltgid auth.KGID msize uint32 version string @@ -230,6 +226,15 @@ type InternalFilesystemOptions struct { OpenSocketsByConnecting bool } +// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default +// UIDs and GIDs used for files that do not provide a specific owner or group +// respectively. +const ( + // uint32(-2) doesn't work in Go. + _V9FS_DEFUID = auth.KUID(4294967294) + _V9FS_DEFGID = auth.KGID(4294967294) +) + // Name implements vfs.FilesystemType.Name. func (FilesystemType) Name() string { return Name @@ -315,6 +320,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } + // Parse the default UID and GID. + fsopts.dfltuid = _V9FS_DEFUID + if dfltuidstr, ok := mopts["dfltuid"]; ok { + delete(mopts, "dfltuid") + dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr) + return nil, nil, syserror.EINVAL + } + // In Linux, dfltuid is interpreted as a UID and is converted to a KUID + // in the caller's user namespace, but goferfs isn't + // application-mountable. + fsopts.dfltuid = auth.KUID(dfltuid) + } + fsopts.dfltgid = _V9FS_DEFGID + if dfltgidstr, ok := mopts["dfltgid"]; ok { + delete(mopts, "dfltgid") + dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr) + return nil, nil, syserror.EINVAL + } + fsopts.dfltgid = auth.KGID(dfltgid) + } + // Parse the 9P message size. fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M if msizestr, ok := mopts["msize"]; ok { @@ -422,8 +452,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt client: client, clock: ktime.RealtimeClockFromContext(ctx), devMinor: devMinor, - uid: creds.EffectiveKUID, - gid: creds.EffectiveKGID, syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), } @@ -672,8 +700,8 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma file: file, ino: qid.Path, mode: uint32(attr.Mode), - uid: uint32(fs.uid), - gid: uint32(fs.gid), + uid: uint32(fs.opts.dfltuid), + gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, handle: handle{ fd: -1, @@ -1011,6 +1039,18 @@ func (d *dentry) decRefLocked() { } } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) Watches() *vfs.Watches { + return nil +} + // checkCachingLocked should be called after d's reference count becomes 0 or it // becomes disowned. // diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go new file mode 100644 index 000000000..7294de7d6 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -0,0 +1,97 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create +// pipes after sentry initialization due to syscall filters. +var ( + tempPipeMu sync.Mutex + tempPipeReadFD int + tempPipeWriteFD int + tempPipeBuf [1]byte +) + +func init() { + var pipeFDs [2]int + if err := unix.Pipe(pipeFDs[:]); err != nil { + panic(fmt.Sprintf("failed to create pipe for gofer.blockUntilNonblockingPipeHasWriter: %v", err)) + } + tempPipeReadFD = pipeFDs[0] + tempPipeWriteFD = pipeFDs[1] +} + +func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { + for { + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { + return err + } + } +} + +func nonblockingPipeHasWriter(fd int32) (bool, error) { + tempPipeMu.Lock() + defer tempPipeMu.Unlock() + // Copy 1 byte from fd into the temporary pipe. + n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK) + if err == syserror.EAGAIN { + // The pipe represented by fd is empty, but has a writer. + return true, nil + } + if err != nil { + return false, err + } + if n == 0 { + // The pipe represented by fd is empty and has no writer. + return false, nil + } + // The pipe represented by fd is non-empty, so it either has, or has + // previously had, a writer. Remove the byte copied to the temporary pipe + // before returning. + if n, err := unix.Read(tempPipeReadFD, tempPipeBuf[:]); err != nil || n != 1 { + panic(fmt.Sprintf("failed to drain pipe for gofer.blockUntilNonblockingPipeHasWriter: got (%d, %v), wanted (1, nil)", n, err)) + } + return true, nil +} + +func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error { + t := time.NewTimer(100 * time.Millisecond) + defer t.Stop() + cancel := ctx.SleepStart() + select { + case <-t.C: + ctx.SleepFinish(true) + return nil + case <-cancel: + ctx.SleepFinish(false) + return syserror.ErrInterrupted + } +} diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index a464e6a94..ff6126b87 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -19,17 +19,18 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) -// specialFileFD implements vfs.FileDescriptionImpl for files other than -// regular files, directories, and symlinks: pipes, sockets, etc. It is also -// used for regular files when filesystemOptions.specialRegularFiles is in -// effect. specialFileFD differs from regularFileFD by using per-FD handles -// instead of shared per-dentry handles, and never buffering I/O. +// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device +// special files, and (when filesystemOptions.specialRegularFiles is in effect) +// regular files. specialFileFD differs from regularFileFD by using per-FD +// handles instead of shared per-dentry handles, and never buffering I/O. type specialFileFD struct { fileDescription @@ -40,13 +41,47 @@ type specialFileFD struct { // file offset is significant, i.e. a regular file. seekable is immutable. seekable bool + // mayBlock is true if this file description represents a file for which + // queue may send I/O readiness events. mayBlock is immutable. + mayBlock bool + queue waiter.Queue + // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex off int64 } +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { + ftype := d.fileType() + seekable := ftype == linux.S_IFREG + mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK + fd := &specialFileFD{ + handle: h, + seekable: seekable, + mayBlock: mayBlock, + } + if mayBlock && h.fd >= 0 { + if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { + return nil, err + } + } + if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ + DenyPRead: !seekable, + DenyPWrite: !seekable, + }); err != nil { + if mayBlock && h.fd >= 0 { + fdnotifier.RemoveFD(h.fd) + } + return nil, err + } + return fd, nil +} + // Release implements vfs.FileDescriptionImpl.Release. func (fd *specialFileFD) Release() { + if fd.mayBlock && fd.handle.fd >= 0 { + fdnotifier.RemoveFD(fd.handle.fd) + } fd.handle.close(context.Background()) fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) fs.syncMu.Lock() @@ -62,6 +97,32 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { return fd.handle.file.flush(ctx) } +// Readiness implements waiter.Waitable.Readiness. +func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { + if fd.mayBlock { + return fdnotifier.NonBlockingPoll(fd.handle.fd, mask) + } + return fd.fileDescription.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + if fd.mayBlock { + fd.queue.EventRegister(e, mask) + return + } + fd.fileDescription.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { + if fd.mayBlock { + fd.queue.EventUnregister(e) + return + } + fd.fileDescription.EventUnregister(e) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { @@ -81,6 +142,9 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs } buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + if err == syserror.EAGAIN { + err = syserror.ErrWouldBlock + } if n == 0 { return 0, err } @@ -130,6 +194,9 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + if err == syserror.EAGAIN { + err = syserror.ErrWouldBlock + } return int64(n), err } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 2608e7e1d..1d5aa82dc 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -38,6 +38,9 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { // Preconditions: fs.interop != InteropModeShared. func (d *dentry) touchAtime(mnt *vfs.Mount) { + if mnt.Flags.NoATime { + return + } if err := mnt.CheckBeginWrite(); err != nil { return } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index a83151ad3..bbee8ccda 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -225,9 +225,21 @@ func (d *Dentry) destroy() { } } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *Dentry) Watches() *vfs.Watches { + return nil +} + // InsertChild inserts child into the vfs dentry cache with the given name under // this dentry. This does not update the directory inode, so calling this on -// it's own isn't sufficient to insert a child into a directory. InsertChild +// its own isn't sufficient to insert a child into a directory. InsertChild // updates the link count on d if required. // // Precondition: d must represent a directory inode. diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 007be1572..062321cbc 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -59,6 +59,7 @@ go_library( "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/uniqueid", "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sentry/vfs/lock", diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index f2399981b..70387cb9c 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -79,6 +79,7 @@ func (dir *directory) removeChildLocked(child *dentry) { dir.iterMu.Lock() dir.childList.Remove(child) dir.iterMu.Unlock() + child.unlinked = true } type directoryFD struct { @@ -112,6 +113,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba dir.iterMu.Lock() defer dir.iterMu.Unlock() + fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) fd.inode().touchAtime(fd.vfsfd.Mount()) if fd.off == 0 { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 80fa7b29d..183eb975c 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -177,6 +177,12 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if err := create(parentDir, name); err != nil { return err } + + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent) parentDir.inode.touchCMtime() return nil } @@ -241,6 +247,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EMLINK } d.inode.incLinksLocked() + d.inode.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent) parentDir.insertChildLocked(fs.newDentry(d.inode), name) return nil }) @@ -354,6 +361,7 @@ afterTrailingSymlink: if err != nil { return nil, err } + parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent) parentDir.inode.touchCMtime() return fd, nil } @@ -559,6 +567,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa newParentDir.inode.touchCMtime() } renamed.inode.touchCtime() + + vfs.InotifyRename(ctx, &renamed.inode.watches, &oldParentDir.inode.watches, &newParentDir.inode.watches, oldName, newName, renamed.inode.isDir()) return nil } @@ -603,8 +613,11 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } parentDir.removeChildLocked(child) - parentDir.inode.decLinksLocked() // from child's ".." + parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent) + // Remove links for child, child/., and child/.. child.inode.decLinksLocked() + child.inode.decLinksLocked() + parentDir.inode.decLinksLocked() vfsObj.CommitDeleteDentry(&child.vfsd) parentDir.inode.touchCMtime() return nil @@ -618,7 +631,14 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if err != nil { return err } - return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat) + if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { + return err + } + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // StatAt implements vfs.FilesystemImpl.StatAt. @@ -698,6 +718,12 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } + + // Generate inotify events. Note that this must take place before the link + // count of the child is decremented, or else the watches may be dropped + // before these events are added. + vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name) + parentDir.removeChildLocked(child) child.inode.decLinksLocked() vfsObj.CommitDeleteDentry(&child.vfsd) @@ -754,7 +780,12 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return err } - return d.inode.setxattr(rp.Credentials(), &opts) + if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { + return err + } + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. @@ -765,7 +796,12 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, if err != nil { return err } - return d.inode.removexattr(rp.Credentials(), name) + if err := d.inode.removexattr(rp.Credentials(), name); err != nil { + return err + } + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 3f433d666..fee174375 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -312,7 +312,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off f := fd.inode().impl.(*regularFile) if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EFBIG + return 0, syserror.EINVAL } var err error diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 1e781aecd..3777ebdf2 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -30,6 +30,7 @@ package tmpfs import ( "fmt" "math" + "strconv" "strings" "sync/atomic" @@ -124,14 +125,45 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } fs.vfsfs.Init(vfsObj, newFSType, &fs) + mopts := vfs.GenericParseMountOptions(opts.Data) + + defaultMode := linux.FileMode(0777) + if modeStr, ok := mopts["mode"]; ok { + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil { + return nil, nil, fmt.Errorf("Mount option \"mode='%v'\" not parsable: %v", modeStr, err) + } + defaultMode = linux.FileMode(mode) + } + + defaultOwnerCreds := creds.Fork() + if uidStr, ok := mopts["uid"]; ok { + uid, err := strconv.ParseInt(uidStr, 10, 32) + if err != nil { + return nil, nil, fmt.Errorf("Mount option \"uid='%v'\" not parsable: %v", uidStr, err) + } + if err := defaultOwnerCreds.SetUID(auth.UID(uid)); err != nil { + return nil, nil, fmt.Errorf("Error using mount option \"uid='%v'\": %v", uidStr, err) + } + } + if gidStr, ok := mopts["gid"]; ok { + gid, err := strconv.ParseInt(gidStr, 10, 32) + if err != nil { + return nil, nil, fmt.Errorf("Mount option \"gid='%v'\" not parsable: %v", gidStr, err) + } + if err := defaultOwnerCreds.SetGID(auth.GID(gid)); err != nil { + return nil, nil, fmt.Errorf("Error using mount option \"gid='%v'\": %v", gidStr, err) + } + } + var root *dentry switch rootFileType { case linux.S_IFREG: - root = fs.newDentry(fs.newRegularFile(creds, 0777)) + root = fs.newDentry(fs.newRegularFile(defaultOwnerCreds, defaultMode)) case linux.S_IFLNK: - root = fs.newDentry(fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget)) + root = fs.newDentry(fs.newSymlink(defaultOwnerCreds, tmpfsOpts.RootSymlinkTarget)) case linux.S_IFDIR: - root = &fs.newDirectory(creds, 01777).dentry + root = &fs.newDirectory(defaultOwnerCreds, defaultMode).dentry default: fs.vfsfs.DecRef() return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType) @@ -163,6 +195,11 @@ type dentry struct { // filesystem.mu. name string + // unlinked indicates whether this dentry has been unlinked from its parent. + // It is only set to true on an unlink operation, and never set from true to + // false. unlinked is protected by filesystem.mu. + unlinked bool + // dentryEntry (ugh) links dentries into their parent directory.childList. dentryEntry @@ -201,6 +238,26 @@ func (d *dentry) DecRef() { d.inode.decRef() } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { + if d.inode.isDir() { + events |= linux.IN_ISDIR + } + + // The ordering below is important, Linux always notifies the parent first. + if d.parent != nil { + // Note that d.parent or d.name may be stale if there is a concurrent + // rename operation. Inotify does not provide consistency guarantees. + d.parent.inode.watches.NotifyWithExclusions(d.name, events, cookie, et, d.unlinked) + } + d.inode.watches.Notify("", events, cookie, et) +} + +// Watches implements vfs.DentryImpl.Watches. +func (d *dentry) Watches() *vfs.Watches { + return &d.inode.watches +} + // inode represents a filesystem object. type inode struct { // fs is the owning filesystem. fs is immutable. @@ -209,11 +266,9 @@ type inode struct { // refs is a reference count. refs is accessed using atomic memory // operations. // - // A reference is held on all inodes that are reachable in the filesystem - // tree. For non-directories (which may have multiple hard links), this - // means that a reference is dropped when nlink reaches 0. For directories, - // nlink never reaches 0 due to the "." entry; instead, - // filesystem.RmdirAt() drops the reference. + // A reference is held on all inodes as long as they are reachable in the + // filesystem tree, i.e. nlink is nonzero. This reference is dropped when + // nlink reaches 0. refs int64 // xattrs implements extended attributes. @@ -238,6 +293,9 @@ type inode struct { // Advisory file locks, which lock at the inode level. locks lock.FileLocks + // Inotify watches for this inode. + watches vfs.Watches + impl interface{} // immutable } @@ -259,6 +317,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, i.ctime = now i.mtime = now // i.nlink initialized by caller + i.watches = vfs.Watches{} i.impl = impl } @@ -276,14 +335,17 @@ func (i *inode) incLinksLocked() { atomic.AddUint32(&i.nlink, 1) } -// decLinksLocked decrements i's link count. +// decLinksLocked decrements i's link count. If the link count reaches 0, we +// remove a reference on i as well. // // Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. func (i *inode) decLinksLocked() { if i.nlink == 0 { panic("tmpfs.inode.decLinksLocked() called with no existing links") } - atomic.AddUint32(&i.nlink, ^uint32(0)) + if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 { + i.decRef() + } } func (i *inode) incRef() { @@ -306,6 +368,7 @@ func (i *inode) tryIncRef() bool { func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { + i.watches.HandleDeletion() if regFile, ok := i.impl.(*regularFile); ok { // Release memory used by regFile to store data. Since regFile is // no longer usable, we don't need to grab any locks or update any @@ -531,6 +594,9 @@ func (i *inode) isDir() bool { } func (i *inode) touchAtime(mnt *vfs.Mount) { + if mnt.Flags.NoATime { + return + } if err := mnt.CheckBeginWrite(); err != nil { return } @@ -627,8 +693,12 @@ func (fd *fileDescription) filesystem() *filesystem { return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } +func (fd *fileDescription) dentry() *dentry { + return fd.vfsfd.Dentry().Impl().(*dentry) +} + func (fd *fileDescription) inode() *inode { - return fd.vfsfd.Dentry().Impl().(*dentry).inode + return fd.dentry().inode } // Stat implements vfs.FileDescriptionImpl.Stat. @@ -641,7 +711,15 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - return fd.inode().setStat(ctx, creds, &opts.Stat) + d := fd.dentry() + if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + return err + } + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. @@ -656,12 +734,26 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption // Setxattr implements vfs.FileDescriptionImpl.Setxattr. func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts) + d := fd.dentry() + if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil { + return err + } + + // Generate inotify events. + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // Removexattr implements vfs.FileDescriptionImpl.Removexattr. func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name) + d := fd.dentry() + if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil { + return err + } + + // Generate inotify events. + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // NewMemfd creates a new tmpfs regular file and file description that can back diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go index e057d2c6d..6862f2ef5 100644 --- a/pkg/sentry/kernel/auth/credentials.go +++ b/pkg/sentry/kernel/auth/credentials.go @@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) { } return NoID, syserror.EPERM } + +// SetUID translates the provided uid to the root user namespace and updates c's +// uids to it. This performs no permissions or capabilities checks, the caller +// is responsible for ensuring the calling context is permitted to modify c. +func (c *Credentials) SetUID(uid UID) error { + kuid := c.UserNamespace.MapToKUID(uid) + if !kuid.Ok() { + return syserror.EINVAL + } + c.RealKUID = kuid + c.EffectiveKUID = kuid + c.SavedKUID = kuid + return nil +} + +// SetGID translates the provided gid to the root user namespace and updates c's +// gids to it. This performs no permissions or capabilities checks, the caller +// is responsible for ensuring the calling context is permitted to modify c. +func (c *Credentials) SetGID(gid GID) error { + kgid := c.UserNamespace.MapToKGID(gid) + if !kgid.Ok() { + return syserror.EINVAL + } + c.RealKGID = kgid + c.EffectiveKGID = kgid + c.SavedKGID = kgid + return nil +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index ed40b5303..dbfcef0fa 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -152,7 +152,13 @@ func (f *FDTable) drop(file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(file *vfs.FileDescription) { // TODO(gvisor.dev/issue/1480): Release locks. - // TODO(gvisor.dev/issue/1479): Send inotify events. + + // Generate inotify events. + ev := uint32(linux.IN_CLOSE_NOWRITE) + if file.IsWritable() { + ev = linux.IN_CLOSE_WRITE + } + file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) // Drop the table reference. file.DecRef() diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index f29dc0472..7bfa9075a 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -8,6 +8,7 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_unsafe.go", "pipe_util.go", "reader.go", "reader_writer.go", @@ -20,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 62c8691f1..79645d7d2 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -207,7 +207,10 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.readLocked(ctx, ops) +} +func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { // Is the pipe empty? if p.view.Size() == 0 { if !p.HasWriters() { @@ -246,7 +249,10 @@ type writeOps struct { func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.writeLocked(ctx, ops) +} +func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { // Can't write to a pipe with no readers. if !p.HasReaders() { return 0, syscall.EPIPE diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go new file mode 100644 index 000000000..dd60cba24 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "unsafe" +) + +// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be +// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that +// concurrent calls cannot deadlock. +// +// Preconditions: x != y. +func lockTwoPipes(x, y *Pipe) { + // Lock the two pipes in order of increasing address. + if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) { + x.mu.Lock() + y.mu.Lock() + } else { + y.mu.Lock() + x.mu.Lock() + } +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index b54f08a30..2602bed72 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,7 +16,9 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +152,9 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) * return &fd.vfsfd } -// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. +// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements +// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to +// other FileDescriptions for splice(2) and tee(2). type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -229,3 +233,216 @@ func (fd *VFSPipeFD) PipeSize() int64 { func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { return fd.pipe.SetFifoSize(size) } + +// IOSequence returns a useremm.IOSequence that reads up to count bytes from, +// or writes up to count bytes to, fd. +func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence { + return usermem.IOSequence{ + IO: fd, + Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + } +} + +// CopyIn implements usermem.IO.CopyIn. +func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(dst)) + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return int64(len(dst)) + }, + limit: func(l int64) { + dst = dst[:l] + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadAt(dst, 0) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// CopyOut implements usermem.IO.CopyOut. +func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(src)) + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return int64(len(src)) + }, + limit: func(l int64) { + src = src[:l] + }, + write: func(view *buffer.View) (int64, error) { + view.Append(src) + return int64(len(src)), nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// ZeroOut implements usermem.IO.ZeroOut. +func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { + origCount := toZero + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return toZero + }, + limit: func(l int64) { + toZero = l + }, + write: func(view *buffer.View) (int64, error) { + view.Grow(view.Size()+toZero, true /* zero */) + return toZero, nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyInTo implements usermem.IO.CopyInTo. +func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadToSafememWriter(dst, uint64(count)) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyOutFrom implements usermem.IO.CopyOutFrom. +func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(view *buffer.View) (int64, error) { + n, err := view.WriteFromSafememReader(src, uint64(count)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// SwapUint32 implements usermem.IO.SwapUint32. +func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { + // How did a pipe get passed as the virtual address space to futex(2)? + panic("VFSPipeFD.SwapUint32 called unexpectedly") +} + +// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32. +func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly") +} + +// LoadUint32 implements usermem.IO.LoadUint32. +func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.LoadUint32 called unexpectedly") +} + +// Splice reads up to count bytes from src and writes them to dst. It returns +// the number of bytes moved. +// +// Preconditions: count > 0. +func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */) +} + +// Tee reads up to count bytes from src and writes them to dst, without +// removing the read bytes from src. It returns the number of bytes copied. +// +// Preconditions: count > 0. +func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */) +} + +// Preconditions: count > 0. +func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) { + if dst.pipe == src.pipe { + return 0, syserror.EINVAL + } + + lockTwoPipes(dst.pipe, src.pipe) + defer dst.pipe.mu.Unlock() + defer src.pipe.mu.Unlock() + + n, err := dst.pipe.writeLocked(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(dstView *buffer.View) (int64, error) { + return src.pipe.readLocked(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(srcView *buffer.View) (int64, error) { + n, err := srcView.ReadToSafememWriter(dstView, uint64(count)) + if n > 0 && removeFromSrc { + srcView.TrimFront(int64(n)) + } + return int64(n), err + }, + }) + }, + }) + if n > 0 { + dst.pipe.Notify(waiter.EventIn) + src.pipe.Notify(waiter.EventOut) + } + return n, err +} diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 00c425cca..9b69f3cbe 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -198,6 +198,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) t.tg.pidns.owner.mu.Unlock() + oldFDTable := t.fdTable + t.fdTable = t.fdTable.Fork() + oldFDTable.DecRef() + // Remove FDs with the CloseOnExec flag set. t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { return flags.CloseOnExec diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 1eeb9f317..a9836ba71 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -33,6 +33,7 @@ go_template_instance( out = "usage_set.go", consts = { "minDegree": "10", + "trackGaps": "1", }, imports = { "platform": "gvisor.dev/gvisor/pkg/sentry/platform", @@ -48,6 +49,26 @@ go_template_instance( }, ) +go_template_instance( + name = "reclaim_set", + out = "reclaim_set.go", + consts = { + "minDegree": "10", + }, + imports = { + "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + }, + package = "pgalloc", + prefix = "reclaim", + template = "//pkg/segment:generic_set", + types = { + "Key": "uint64", + "Range": "platform.FileRange", + "Value": "reclaimSetValue", + "Functions": "reclaimSetFunctions", + }, +) + go_library( name = "pgalloc", srcs = [ @@ -56,6 +77,7 @@ go_library( "evictable_range_set.go", "pgalloc.go", "pgalloc_unsafe.go", + "reclaim_set.go", "save_restore.go", "usage_set.go", ], diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 2b11ea4ae..c8d9facc2 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -108,12 +108,6 @@ type MemoryFile struct { usageSwapped uint64 usageLast time.Time - // minUnallocatedPage is the minimum page that may be unallocated. - // i.e., there are no unallocated pages below minUnallocatedPage. - // - // minUnallocatedPage is protected by mu. - minUnallocatedPage uint64 - // fileSize is the size of the backing memory file in bytes. fileSize is // always a power-of-two multiple of chunkSize. // @@ -146,11 +140,9 @@ type MemoryFile struct { // is protected by mu. reclaimable bool - // minReclaimablePage is the minimum page that may be reclaimable. - // i.e., all reclaimable pages are >= minReclaimablePage. - // - // minReclaimablePage is protected by mu. - minReclaimablePage uint64 + // relcaim is the collection of regions for reclaim. relcaim is protected + // by mu. + reclaim reclaimSet // reclaimCond is signaled (with mu locked) when reclaimable or destroyed // transitions from false to true. @@ -273,12 +265,10 @@ type evictableMemoryUserInfo struct { } const ( - chunkShift = 24 - chunkSize = 1 << chunkShift // 16 MB + chunkShift = 30 + chunkSize = 1 << chunkShift // 1 GB chunkMask = chunkSize - 1 - initialSize = chunkSize - // maxPage is the highest 64-bit page. maxPage = math.MaxUint64 &^ (usermem.PageSize - 1) ) @@ -302,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { if err := file.Truncate(0); err != nil { return nil, err } - if err := file.Truncate(initialSize); err != nil { - return nil, err - } f := &MemoryFile{ - opts: opts, - fileSize: initialSize, - file: file, - // No pages are reclaimable. DecRef will always be able to - // decrease minReclaimablePage from this point. - minReclaimablePage: maxPage, - evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo), + opts: opts, + file: file, + evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo), } - f.mappings.Store(make([]uintptr, initialSize/chunkSize)) + f.mappings.Store(make([]uintptr, 0)) f.reclaimCond.L = &f.mu if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure { @@ -404,39 +387,28 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi alignment = usermem.HugePageSize } - start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment) - end := start + length - // File offsets are int64s. Since length must be strictly positive, end - // cannot legitimately be 0. - if end < start || int64(end) <= 0 { + // Find a range in the underlying file. + fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) + if !ok { return platform.FileRange{}, syserror.ENOMEM } - // Expand the file if needed. Double the file size on each expansion; - // uncommitted pages have effectively no cost. - fileSize := f.fileSize - for int64(end) > fileSize { - if fileSize >= 2*fileSize { - // fileSize overflow. - return platform.FileRange{}, syserror.ENOMEM - } - fileSize *= 2 - } - if fileSize > f.fileSize { - if err := f.file.Truncate(fileSize); err != nil { + // Expand the file if needed. Note that findAvailableRange will + // appropriately double the fileSize when required. + if int64(fr.End) > f.fileSize { + if err := f.file.Truncate(int64(fr.End)); err != nil { return platform.FileRange{}, err } - f.fileSize = fileSize + f.fileSize = int64(fr.End) f.mappingsMu.Lock() oldMappings := f.mappings.Load().([]uintptr) - newMappings := make([]uintptr, fileSize>>chunkShift) + newMappings := make([]uintptr, f.fileSize>>chunkShift) copy(newMappings, oldMappings) f.mappings.Store(newMappings) f.mappingsMu.Unlock() } // Mark selected pages as in use. - fr := platform.FileRange{start, end} if f.opts.ManualZeroing { if err := f.forEachMappingSlice(fr, func(bs []byte) { for i := range bs { @@ -453,49 +425,71 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage)) } - if minUnallocatedPage < start { - f.minUnallocatedPage = minUnallocatedPage - } else { - // start was the first unallocated page. The next must be - // somewhere beyond end. - f.minUnallocatedPage = end - } - return fr, nil } -// findUnallocatedRange returns the first unallocated page in usage of the -// specified length and alignment beginning at page start and the first single -// unallocated page. -func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) { - // Only searched until the first page is found. - firstPage := start - foundFirstPage := false - alignMask := alignment - 1 - for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() { - r := seg.Range() +// findAvailableRange returns an available range in the usageSet. +// +// Note that scanning for available slots takes place from end first backwards, +// then forwards. This heuristic has important consequence for how sequential +// mappings can be merged in the host VMAs, given that addresses for both +// application and sentry mappings are allocated top-down (from higher to +// lower addresses). The file is also grown expoentially in order to create +// space for mappings to be allocated downwards. +// +// Precondition: alignment must be a power of 2. +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { + alignmentMask := alignment - 1 + for gap := usage.UpperBoundGap(uint64(fileSize)); gap.Ok(); gap = gap.PrevLargeEnoughGap(length) { + // Start searching only at end of file. + end := gap.End() + if end > uint64(fileSize) { + end = uint64(fileSize) + } - if !foundFirstPage && r.Start > firstPage { - foundFirstPage = true + // Start at the top and align downwards. + start := end - length + if start > end { + break // Underflow. } + start &^= alignmentMask - if start >= r.End { - // start was rounded up to an alignment boundary from the end - // of a previous segment and is now beyond r.End. + // Is the gap still sufficient? + if start < gap.Start() { continue } - // This segment represents allocated or reclaimable pages; only the - // range from start to the segment's beginning is allocatable, and the - // next allocatable range begins after the segment. - if r.Start > start && r.Start-start >= length { - break + + // Allocate in the given gap. + return platform.FileRange{start, start + length}, true + } + + // Check that it's possible to fit this allocation at the end of a file of any size. + min := usage.LastGap().Start() + min = (min + alignmentMask) &^ alignmentMask + if min+length < min { + // Overflow. + return platform.FileRange{}, false + } + + // Determine the minimum file size required to fit this allocation at its end. + for { + if fileSize >= 2*fileSize { + // Is this because it's initially empty? + if fileSize == 0 { + fileSize += chunkSize + } else { + // fileSize overflow. + return platform.FileRange{}, false + } + } else { + // Double the current fileSize. + fileSize *= 2 } - start = (r.End + alignMask) &^ alignMask - if !foundFirstPage { - firstPage = r.End + start := (uint64(fileSize) - length) &^ alignmentMask + if start >= min { + return platform.FileRange{start, start + length}, true } } - return start, firstPage } // AllocateAndFill allocates memory of the given kind and fills it by calling @@ -616,6 +610,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } val.refs-- if val.refs == 0 { + f.reclaim.Add(seg.Range(), reclaimSetValue{}) freed = true // Reclassify memory as System, until it's freed by the reclaim // goroutine. @@ -628,10 +623,6 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) if freed { - if fr.Start < f.minReclaimablePage { - // We've freed at least one lower page. - f.minReclaimablePage = fr.Start - } f.reclaimable = true f.reclaimCond.Signal() } @@ -1030,6 +1021,7 @@ func (f *MemoryFile) String() string { // for allocation. func (f *MemoryFile) runReclaim() { for { + // N.B. We must call f.markReclaimed on the returned FrameRange. fr, ok := f.findReclaimable() if !ok { break @@ -1085,6 +1077,10 @@ func (f *MemoryFile) runReclaim() { } } +// findReclaimable finds memory that has been marked for reclaim. +// +// Note that there returned range will be removed from tracking. It +// must be reclaimed (removed from f.usage) at this point. func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() @@ -1103,18 +1099,15 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } f.reclaimCond.Wait() } - // Allocate returns the first usable range in offset order and is - // currently a linear scan, so reclaiming from the beginning of the - // file minimizes the expected latency of Allocate. - for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() { - if seg.ValuePtr().refs == 0 { - f.minReclaimablePage = seg.End() - return seg.Range(), true - } + // Allocate works from the back of the file inwards, so reclaim + // preserves this order to minimize the cost of the search. + if seg := f.reclaim.LastSegment(); seg.Ok() { + fr := seg.Range() + f.reclaim.Remove(seg) + return fr, true } - // No pages are reclaimable. + // Nothing is reclaimable. f.reclaimable = false - f.minReclaimablePage = maxPage } } @@ -1122,8 +1115,8 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) - // All of fr should be mapped to a single uncommitted reclaimable segment - // accounted to System. + // All of fr should be mapped to a single uncommitted reclaimable + // segment accounted to System. if !seg.Ok() { panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage)) } @@ -1137,14 +1130,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) { }); got != want { panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage)) } - // Deallocate reclaimed pages. Even though all of seg is reclaimable, the - // caller of markReclaimed may not have decommitted it, so we can only mark - // fr as reclaimed. + // Deallocate reclaimed pages. Even though all of seg is reclaimable, + // the caller of markReclaimed may not have decommitted it, so we can + // only mark fr as reclaimed. f.usage.Remove(f.usage.Isolate(seg, fr)) - if fr.Start < f.minUnallocatedPage { - // We've deallocated at least one lower page. - f.minUnallocatedPage = fr.Start - } } // StartEvictions requests that f evict all evictable allocations. It does not @@ -1255,3 +1244,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) { return evictableRangeSetValue{}, evictableRangeSetValue{} } + +// reclaimSetValue is the value type of reclaimSet. +type reclaimSetValue struct{} + +type reclaimSetFunctions struct{} + +func (reclaimSetFunctions) MinKey() uint64 { + return 0 +} + +func (reclaimSetFunctions) MaxKey() uint64 { + return math.MaxUint64 +} + +func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { +} + +func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { + return reclaimSetValue{}, true +} + +func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { + return reclaimSetValue{}, reclaimSetValue{} +} diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go index 293f22c6b..b5b68eb52 100644 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ b/pkg/sentry/pgalloc/pgalloc_test.go @@ -23,39 +23,49 @@ import ( const ( page = usermem.PageSize hugepage = usermem.HugePageSize + topPage = (1 << 63) - page ) func TestFindUnallocatedRange(t *testing.T) { for _, test := range []struct { - desc string - usage *usageSegmentDataSlices - start uint64 - length uint64 - alignment uint64 - unallocated uint64 - minUnallocated uint64 + desc string + usage *usageSegmentDataSlices + fileSize int64 + length uint64 + alignment uint64 + start uint64 + expectFail bool }{ { - desc: "Initial allocation succeeds", - usage: &usageSegmentDataSlices{}, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, + desc: "Initial allocation succeeds", + usage: &usageSegmentDataSlices{}, + length: page, + alignment: page, + start: chunkSize - page, // Grows by chunkSize, allocate down. }, { - desc: "Allocation begins at start of file", + desc: "Allocation finds empty space at start of file", usage: &usageSegmentDataSlices{ Start: []uint64{page}, End: []uint64{2 * page}, Values: []usageInfo{{refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, + fileSize: 2 * page, + length: page, + alignment: page, + start: 0, + }, + { + desc: "Allocation finds empty space at end of file", + usage: &usageSegmentDataSlices{ + Start: []uint64{0}, + End: []uint64{page}, + Values: []usageInfo{{refs: 1}}, + }, + fileSize: 2 * page, + length: page, + alignment: page, + start: page, }, { desc: "In-use frames are not allocatable", @@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 2 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 2 * page, + length: page, + alignment: page, + start: 3 * page, // Double fileSize, allocate top-down. }, { desc: "Reclaimable frames are not allocatable", @@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 2 * page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 3 * page, - minUnallocated: 3 * page, + fileSize: 3 * page, + length: page, + alignment: page, + start: 5 * page, // Double fileSize, grow down. }, { desc: "Gaps between in-use frames are allocatable", @@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: page, - minUnallocated: page, + fileSize: 3 * page, + length: page, + alignment: page, + start: page, }, { desc: "Inadequately-sized gaps are rejected", @@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: 2 * page, - alignment: page, - unallocated: 3 * page, - minUnallocated: page, + fileSize: 3 * page, + length: 2 * page, + alignment: page, + start: 4 * page, // Double fileSize, grow down. }, { - desc: "Hugepage alignment is honored", + desc: "Alignment is honored at end of file", usage: &usageSegmentDataSlices{ Start: []uint64{0, hugepage + page}, // Hugepage-sized gap here that shouldn't be allocated from @@ -118,37 +124,95 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, hugepage + 2*page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: hugepage, - alignment: hugepage, - unallocated: 2 * hugepage, - minUnallocated: page, + fileSize: hugepage + 2*page, + length: hugepage, + alignment: hugepage, + start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down. + }, + { + desc: "Alignment is honored before end of file", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, 2*hugepage + page}, + // Page will need to be shifted down from top. + End: []uint64{page, 2*hugepage + 2*page}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + fileSize: 2*hugepage + 2*page, + length: hugepage, + alignment: hugepage, + start: hugepage, }, { - desc: "Pages before start ignored", + desc: "Allocations are compact if possible", usage: &usageSegmentDataSlices{ Start: []uint64{page, 3 * page}, End: []uint64{2 * page, 4 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 4 * page, + length: page, + alignment: page, + start: 2 * page, + }, + { + desc: "Top-down allocation within one gap", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, 4 * page, 7 * page}, + End: []uint64{2 * page, 5 * page, 8 * page}, + Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, + }, + fileSize: 8 * page, + length: page, + alignment: page, + start: 6 * page, + }, + { + desc: "Top-down allocation between multiple gaps", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, 3 * page, 5 * page}, + End: []uint64{2 * page, 4 * page, 6 * page}, + Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, + }, + fileSize: 6 * page, + length: page, + alignment: page, + start: 4 * page, }, { - desc: "start may be in the middle of segment", + desc: "Top-down allocation with large top gap", usage: &usageSegmentDataSlices{ - Start: []uint64{0, 3 * page}, + Start: []uint64{page, 3 * page}, End: []uint64{2 * page, 4 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 8 * page, + length: page, + alignment: page, + start: 7 * page, + }, + { + desc: "Gaps found with possible overflow", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, topPage - page}, + End: []uint64{2 * page, topPage}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + fileSize: topPage, + length: page, + alignment: page, + start: topPage - 2*page, + }, + { + desc: "Overflow detected", + usage: &usageSegmentDataSlices{ + Start: []uint64{page}, + End: []uint64{topPage}, + Values: []usageInfo{{refs: 1}}, + }, + fileSize: topPage, + length: 2 * page, + alignment: page, + expectFail: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -156,12 +220,18 @@ func TestFindUnallocatedRange(t *testing.T) { if err := usage.ImportSortedSlices(test.usage); err != nil { t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) } - unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment) - if unallocated != test.unallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated) + fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment) + if !test.expectFail && !ok { + t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) + } + if test.expectFail && ok { + t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) + } + if ok && fr.Start != test.start { + t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) } - if minUnallocated != test.minUnallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated) + if ok && fr.End != test.start+test.length { + t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length) } }) } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 159f7eafd..4792454c4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -6,8 +6,8 @@ go_library( name = "kvm", srcs = [ "address_space.go", - "allocator.go", "bluepill.go", + "bluepill_allocator.go", "bluepill_amd64.go", "bluepill_amd64.s", "bluepill_amd64_unsafe.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index be213bfe8..faf1d5e1c 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -26,16 +26,15 @@ import ( // dirtySet tracks vCPUs for invalidation. type dirtySet struct { - vCPUs []uint64 + vCPUMasks []uint64 } // forEach iterates over all CPUs in the dirty set. +// +//go:nosplit func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) { - m.mu.RLock() - defer m.mu.RUnlock() - - for index := range ds.vCPUs { - mask := atomic.SwapUint64(&ds.vCPUs[index], 0) + for index := range ds.vCPUMasks { + mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0) if mask != 0 { for bit := 0; bit < 64; bit++ { if mask&(1<<uint64(bit)) == 0 { @@ -54,7 +53,7 @@ func (ds *dirtySet) mark(c *vCPU) bool { index := uint64(c.id) / 64 bit := uint64(1) << uint(c.id%64) - oldValue := atomic.LoadUint64(&ds.vCPUs[index]) + oldValue := atomic.LoadUint64(&ds.vCPUMasks[index]) if oldValue&bit != 0 { return false // Not clean. } @@ -62,7 +61,7 @@ func (ds *dirtySet) mark(c *vCPU) bool { // Set the bit unilaterally, and ensure that a flush takes place. Note // that it's possible for races to occur here, but since the flush is // taking place long after these lines there's no race in practice. - atomicbitops.OrUint64(&ds.vCPUs[index], bit) + atomicbitops.OrUint64(&ds.vCPUMasks[index], bit) return true // Previously clean. } @@ -113,7 +112,12 @@ type hostMapEntry struct { length uintptr } -func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { +// mapLocked maps the given host entry. +// +// +checkescape:hard,stack +// +//go:nosplit +func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { for m.length > 0 { physical, length, ok := translateToPhysical(m.addr) if !ok { @@ -133,18 +137,10 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // important; if the pagetable mappings were installed before // ensuring the physical pages were available, then some other // thread could theoretically access them. - // - // Due to the way KVM's shadow paging implementation works, - // modifications to the page tables while in host mode may not - // be trapped, leading to the shadow pages being out of sync. - // Therefore, we need to ensure that we are in guest mode for - // page table modifications. See the call to bluepill, below. - as.machine.retryInGuest(func() { - inv = as.pageTables.Map(addr, length, pagetables.MapOpts{ - AccessType: at, - User: true, - }, physical) || inv - }) + inv = as.pageTables.Map(addr, length, pagetables.MapOpts{ + AccessType: at, + User: true, + }, physical) || inv m.addr += length m.length -= length addr += usermem.Addr(length) @@ -176,6 +172,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. return err } + // See block in mapLocked. + as.pageTables.Allocator.(*allocator).cpu = as.machine.Get() + defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu) + // Map the mappings in the sentry's address space (guest physical memory) // into the application's address space (guest virtual memory). inv := false @@ -190,7 +190,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. _ = s[i] // Touch to commit. } } - prev := as.mapHost(addr, hostMapEntry{ + + // See bluepill_allocator.go. + bluepill(as.pageTables.Allocator.(*allocator).cpu) + + // Perform the mapping. + prev := as.mapLocked(addr, hostMapEntry{ addr: b.Addr(), length: uintptr(b.Len()), }, at) @@ -204,17 +209,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. return nil } +// unmapLocked is an escape-checked wrapped around Unmap. +// +// +checkescape:hard,stack +// +//go:nosplit +func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool { + return as.pageTables.Unmap(addr, uintptr(length)) +} + // Unmap unmaps the given range by calling pagetables.PageTables.Unmap. func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) { as.mu.Lock() defer as.mu.Unlock() - // See above re: retryInGuest. - var prev bool - as.machine.retryInGuest(func() { - prev = as.pageTables.Unmap(addr, uintptr(length)) || prev - }) - if prev { + // See above & bluepill_allocator.go. + as.pageTables.Allocator.(*allocator).cpu = as.machine.Get() + defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu) + bluepill(as.pageTables.Allocator.(*allocator).cpu) + + if prev := as.unmapLocked(addr, length); prev { + // Invalidate all active vCPUs. as.invalidate() // Recycle any freed intermediate pages. @@ -227,7 +242,7 @@ func (as *addressSpace) Release() { as.Unmap(0, ^uint64(0)) // Free all pages from the allocator. - as.pageTables.Allocator.(allocator).base.Drain() + as.pageTables.Allocator.(*allocator).base.Drain() // Drop all cached machine references. as.machine.dropPageTables(as.pageTables) diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go index 3f35414bb..9485e1301 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/bluepill_allocator.go @@ -21,56 +21,80 @@ import ( ) type allocator struct { - base *pagetables.RuntimeAllocator + base pagetables.RuntimeAllocator + + // cpu must be set prior to any pagetable operation. + // + // Due to the way KVM's shadow paging implementation works, + // modifications to the page tables while in host mode may not be + // trapped, leading to the shadow pages being out of sync. Therefore, + // we need to ensure that we are in guest mode for page table + // modifications. See the call to bluepill, below. + cpu *vCPU } // newAllocator is used to define the allocator. -func newAllocator() allocator { - return allocator{ - base: pagetables.NewRuntimeAllocator(), - } +func newAllocator() *allocator { + a := new(allocator) + a.base.Init() + return a } // NewPTEs implements pagetables.Allocator.NewPTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) NewPTEs() *pagetables.PTEs { - return a.base.NewPTEs() +func (a *allocator) NewPTEs() *pagetables.PTEs { + ptes := a.base.NewPTEs() // escapes: bluepill below. + if a.cpu != nil { + bluepill(a.cpu) + } + return ptes } // PhysicalFor returns the physical address for a set of PTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { +func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { virtual := a.base.PhysicalFor(ptes) physical, _, ok := translateToPhysical(virtual) if !ok { - panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) + panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic. } return physical } // LookupPTEs implements pagetables.Allocator.LookupPTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { +func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { - panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) + panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic. } return a.base.LookupPTEs(virtualStart + (physical - physicalStart)) } // FreePTEs implements pagetables.Allocator.FreePTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) FreePTEs(ptes *pagetables.PTEs) { - a.base.FreePTEs(ptes) +func (a *allocator) FreePTEs(ptes *pagetables.PTEs) { + a.base.FreePTEs(ptes) // escapes: bluepill below. + if a.cpu != nil { + bluepill(a.cpu) + } } // Recycle implements pagetables.Allocator.Recycle. // //go:nosplit -func (a allocator) Recycle() { +func (a *allocator) Recycle() { a.base.Recycle() } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 133c2203d..ddc1554d5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -63,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU { // KernelSyscall handles kernel syscalls. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelSyscall() { regs := c.Registers() @@ -72,13 +74,15 @@ func (c *vCPU) KernelSyscall() { // We only trigger a bluepill entry in the bluepill function, and can // therefore be guaranteed that there is no floating point state to be // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) + ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no. ring0.Halt() - ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment. + ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment. } // KernelException handles kernel exceptions. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelException(vector ring0.Vector) { regs := c.Registers() @@ -89,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) + ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no. ring0.Halt() - ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment. + ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment. } // bluepillArchExit is called during bluepillEnter. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index c215d443c..83643c602 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -66,6 +66,8 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // KernelSyscall handles kernel syscalls. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelSyscall() { regs := c.Registers() @@ -88,6 +90,8 @@ func (c *vCPU) KernelSyscall() { // KernelException handles kernel exceptions. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelException(vector ring0.Vector) { regs := c.Registers() diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 9add7c944..c025aa0bb 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. @@ -64,6 +64,8 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { // signal stack. It should only execute raw system calls and functions that are // explicitly marked go:nosplit. // +// +checkescape:all +// //go:nosplit func bluepillHandler(context unsafe.Pointer) { // Sanitize the registers; interrupts must always be disabled. @@ -82,7 +84,8 @@ func bluepillHandler(context unsafe.Pointer) { } for { - switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno { + _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no. + switch errno { case 0: // Expected case. case syscall.EINTR: // First, we process whatever pending signal @@ -90,7 +93,7 @@ func bluepillHandler(context unsafe.Pointer) { // currently, all signals are masked and the signal // must have been delivered directly to this thread. timeout := syscall.Timespec{} - sig, _, errno := syscall.RawSyscall6( + sig, _, errno := syscall.RawSyscall6( // escapes: no. syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), 0, // siginfo. @@ -125,7 +128,7 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_NMI, 0); errno != 0 { diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index f1afc74dc..6c54712d1 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -52,16 +52,19 @@ type machine struct { // available is notified when vCPUs are available. available sync.Cond - // vCPUs are the machine vCPUs. + // vCPUsByTID are the machine vCPUs. // // These are populated dynamically. - vCPUs map[uint64]*vCPU + vCPUsByTID map[uint64]*vCPU // vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID. - vCPUsByID map[int]*vCPU + vCPUsByID []*vCPU // maxVCPUs is the maximum number of vCPUs supported by the machine. maxVCPUs int + + // nextID is the next vCPU ID. + nextID uint32 } const ( @@ -137,9 +140,8 @@ type dieState struct { // // Precondition: mu must be held. func (m *machine) newVCPU() *vCPU { - id := len(m.vCPUs) - // Create the vCPU. + id := int(atomic.AddUint32(&m.nextID, 1) - 1) fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id)) if errno != 0 { panic(fmt.Sprintf("error creating new vCPU: %v", errno)) @@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU { // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. - m := &machine{ - fd: vm, - vCPUs: make(map[uint64]*vCPU), - vCPUsByID: make(map[int]*vCPU), - } + m := &machine{fd: vm} m.available.L = &m.mu m.kernel.Init(ring0.KernelOpts{ PageTables: pagetables.New(newAllocator()), @@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) { } log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) + // Create the vCPUs map/slices. + m.vCPUsByTID = make(map[uint64]*vCPU) + m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -274,6 +276,8 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. +// +//go:nosplit func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) @@ -304,7 +308,11 @@ func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) // Destroy vCPUs. - for _, c := range m.vCPUs { + for _, c := range m.vCPUsByID { + if c == nil { + continue + } + // Ensure the vCPU is not still running in guest mode. This is // possible iff teardown has been done by other threads, and // somehow a single thread has not executed any system calls. @@ -337,7 +345,7 @@ func (m *machine) Get() *vCPU { tid := procid.Current() // Check for an exact match. - if c := m.vCPUs[tid]; c != nil { + if c := m.vCPUsByTID[tid]; c != nil { c.lock() m.mu.RUnlock() return c @@ -356,7 +364,7 @@ func (m *machine) Get() *vCPU { tid = procid.Current() // Recheck for an exact match. - if c := m.vCPUs[tid]; c != nil { + if c := m.vCPUsByTID[tid]; c != nil { c.lock() m.mu.Unlock() return c @@ -364,10 +372,10 @@ func (m *machine) Get() *vCPU { for { // Scan for an available vCPU. - for origTID, c := range m.vCPUs { + for origTID, c := range m.vCPUsByTID { if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { - delete(m.vCPUs, origTID) - m.vCPUs[tid] = c + delete(m.vCPUsByTID, origTID) + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c @@ -375,17 +383,17 @@ func (m *machine) Get() *vCPU { } // Create a new vCPU (maybe). - if len(m.vCPUs) < m.maxVCPUs { + if int(m.nextID) < m.maxVCPUs { c := m.newVCPU() c.lock() - m.vCPUs[tid] = c + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c } // Scan for something not in user mode. - for origTID, c := range m.vCPUs { + for origTID, c := range m.vCPUsByTID { if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) { continue } @@ -403,8 +411,8 @@ func (m *machine) Get() *vCPU { } // Steal the vCPU. - delete(m.vCPUs, origTID) - m.vCPUs[tid] = c + delete(m.vCPUsByTID, origTID) + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c @@ -431,7 +439,7 @@ func (m *machine) Put(c *vCPU) { // newDirtySet returns a new dirty set. func (m *machine) newDirtySet() *dirtySet { return &dirtySet{ - vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64), + vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64), } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 923ce3909..acc823ba6 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -51,9 +51,10 @@ func (m *machine) initArchState() error { recover() debug.SetPanicOnFault(old) }() - m.retryInGuest(func() { - ring0.SetCPUIDFaulting(true) - }) + c := m.Get() + defer m.Put(c) + bluepill(c) + ring0.SetCPUIDFaulting(true) return nil } @@ -89,8 +90,8 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) { defer m.mu.Unlock() // Clear from all PCIDs. - for _, c := range m.vCPUs { - if c.PCIDs != nil { + for _, c := range m.vCPUsByID { + if c != nil && c.PCIDs != nil { c.PCIDs.Drop(pt) } } @@ -335,29 +336,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) } } -// retryInGuest runs the given function in guest mode. -// -// If the function does not complete in guest mode (due to execution of a -// system call due to a GC stall, for example), then it will be retried. The -// given function must be idempotent as a result of the retry mechanism. -func (m *machine) retryInGuest(fn func()) { - c := m.Get() - defer m.Put(c) - for { - c.ClearErrorCode() // See below. - bluepill(c) // Force guest mode. - fn() // Execute the given function. - _, user := c.ErrorCode() - if user { - // If user is set, then we haven't bailed back to host - // mode via a kernel exception or system call. We - // consider the full function to have executed in guest - // mode and we can return. - break - } - } -} - // On x86 platform, the flags for "setMemoryRegion" can always be set as 0. // There is no need to return read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 7156c245f..290f035dd 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -154,7 +154,7 @@ func (c *vCPU) setUserRegisters(uregs *userRegs) error { // //go:nosplit func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_GET_REGS, diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index e42505542..750751aa3 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -60,6 +60,12 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { if !vr.accessType.Write && vr.accessType.Read { rdonlyRegions = append(rdonlyRegions, vr.region) } + + // TODO(gvisor.dev/issue/2686): PROT_NONE should be specially treated. + // Workaround: treated as rdonly temporarily. + if !vr.accessType.Write && !vr.accessType.Read && !vr.accessType.Execute { + rdonlyRegions = append(rdonlyRegions, vr.region) + } }) for _, r := range rdonlyRegions { diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index f04be2ab5..9f86f6a7a 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. @@ -115,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace { // //go:nosplit func (c *vCPU) notify() { - _, _, errno := syscall.RawSyscall6( + _, _, errno := syscall.RawSyscall6( // escapes: no. syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&c.state)), linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG, diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index 2ae6b9f9d..0bee995e4 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 900c0bba7..021693791 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -31,23 +31,39 @@ type defaultHooks struct{} // KernelSyscall implements Hooks.KernelSyscall. // +// +checkescape:all +// //go:nosplit -func (defaultHooks) KernelSyscall() { Halt() } +func (defaultHooks) KernelSyscall() { + Halt() +} // KernelException implements Hooks.KernelException. // +// +checkescape:all +// //go:nosplit -func (defaultHooks) KernelException(Vector) { Halt() } +func (defaultHooks) KernelException(Vector) { + Halt() +} // kernelSyscall is a trampoline. // +// +checkescape:hard,stack +// //go:nosplit -func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() } +func kernelSyscall(c *CPU) { + c.hooks.KernelSyscall() +} // kernelException is a trampoline. // +// +checkescape:hard,stack +// //go:nosplit -func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) } +func kernelException(c *CPU, vector Vector) { + c.hooks.KernelException(vector) +} // Init initializes a new CPU. // diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 0feff8778..d37981dbf 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool { // // Precondition: the Rip, Rsp, Fs and Gs registers must be canonical. // +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) @@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { // Perform the switch. swapgs() // GS will be swapped on return. - WriteFS(uintptr(regs.Fs_base)) // Set application FS. - WriteGS(uintptr(regs.Gs_base)) // Set application GS. - LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point. + WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. + WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. + LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. jumpToKernel() // Switch to upper half. writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { @@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { } writeCR3(uintptr(kernelCR3)) // Return to kernel address space. jumpToUser() // Return to lower half. - SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point. - WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS. + SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. + WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return } diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go index 23fd5c352..8d75b7599 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator.go +++ b/pkg/sentry/platform/ring0/pagetables/allocator.go @@ -53,9 +53,14 @@ type RuntimeAllocator struct { // NewRuntimeAllocator returns an allocator that uses runtime allocation. func NewRuntimeAllocator() *RuntimeAllocator { - return &RuntimeAllocator{ - used: make(map[*PTEs]struct{}), - } + r := new(RuntimeAllocator) + r.Init() + return r +} + +// Init initializes a RuntimeAllocator. +func (r *RuntimeAllocator) Init() { + r.used = make(map[*PTEs]struct{}) } // Recycle returns freed pages to the pool. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 87e88e97d..7f18ac296 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -86,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true } // // Precondition: addr & length must be page-aligned, their sum must not overflow. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { if !opts.AccessType.Any() { @@ -128,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // Precondition: addr & length must be page-aligned. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { w := unmapWalker{ @@ -162,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // Precondition: addr & length must be page-aligned. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { w := emptyWalker{ @@ -197,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false } // Lookup returns the physical address for the given virtual address. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) { mask := uintptr(usermem.PageSize - 1) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 789bb94c8..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -64,6 +64,8 @@ const enableLogging = false var emptyFilter = stack.IPHeaderFilter{ Dst: "\x00\x00\x00\x00", DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", } // nflog logs messages related to the writing and reading of iptables. @@ -142,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -214,11 +212,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI } copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) + copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP } + if rule.Filter.SrcInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } if rule.Filter.OutputInterfaceInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } @@ -566,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } @@ -737,6 +738,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) if n == -1 { @@ -755,6 +759,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, OutputInterface: ifname, OutputInterfaceMask: ifnameMask, OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, @@ -765,15 +772,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // The following features are supported: // - Protocol // - Dst and DstMask + // - Src and SrcMask // - The inverse destination IP check flag // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.Src != emptyInetAddr || - iptip.SrcMask != emptyInetAddr || - iptip.InputInterface != emptyInterface || + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 3863293c7..1b4e0ad79 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 57a1e1c12..ebabdf334 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index cfa9e621d..98b9943f8 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 333e0042e..8f0f5466e 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -50,5 +50,6 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 60df51dae..e1e0c5931 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -33,6 +33,7 @@ import ( "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" @@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index ce5b94ee7..09c6d3b27 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -252,7 +252,7 @@ func (e *connectionedEndpoint) Close() { // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { - return syserr.ErrConnectionRefused + return syserr.ErrWrongProtocolForSocket } // Check if ce is e to avoid a deadlock. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 5b29e9d7f..c4c9db81b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -417,7 +417,18 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer ep.Release() // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } // Write implements fs.FileOperations.Write. diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 39f2b79ec..77c78889d 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -80,6 +80,12 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB } } + if total > 0 { + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0) + } return total, err } diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 2de5e3422..c24946160 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -207,7 +207,11 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si return syserror.EOPNOTSUPP } - return d.Inode.SetXattr(t, d, name, value, flags) + if err := d.Inode.SetXattr(t, d, name, value, flags); err != nil { + return err + } + d.InotifyEvent(linux.IN_ATTRIB, 0) + return nil } func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { @@ -418,7 +422,11 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error { return syserror.EOPNOTSUPP } - return d.Inode.RemoveXattr(t, d, name) + if err := d.Inode.RemoveXattr(t, d, name); err != nil { + return err + } + d.InotifyEvent(linux.IN_ATTRIB, 0) + return nil } // LINT.ThenChange(vfs2/xattr.go) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index f882ef840..c0d005247 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -12,9 +12,11 @@ go_library( "filesystem.go", "fscontext.go", "getdents.go", + "inotify.go", "ioctl.go", "memfd.go", "mmap.go", + "mount.go", "path.go", "pipe.go", "poll.go", @@ -22,6 +24,7 @@ go_library( "setstat.go", "signal.go", "socket.go", + "splice.go", "stat.go", "stat_amd64.go", "stat_arm64.go", diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go new file mode 100644 index 000000000..7d50b6a16 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -0,0 +1,134 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC + +// InotifyInit1 implements the inotify_init1() syscalls. +func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + flags := args[0].Int() + if flags&^allFlags != 0 { + return 0, nil, syserror.EINVAL + } + + ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags)) + if err != nil { + return 0, nil, err + } + defer ino.DecRef() + + fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{ + CloseOnExec: flags&linux.IN_CLOEXEC != 0, + }) + + if err != nil { + return 0, nil, err + } + + return uintptr(fd), nil, nil +} + +// InotifyInit implements the inotify_init() syscalls. +func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + args[0].Value = 0 + return InotifyInit1(t, args) +} + +// fdToInotify resolves an fd to an inotify object. If successful, the file will +// have an extra ref and the caller is responsible for releasing the ref. +func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, error) { + f := t.GetFileVFS2(fd) + if f == nil { + // Invalid fd. + return nil, nil, syserror.EBADF + } + + ino, ok := f.Impl().(*vfs.Inotify) + if !ok { + // Not an inotify fd. + f.DecRef() + return nil, nil, syserror.EINVAL + } + + return ino, f, nil +} + +// InotifyAddWatch implements the inotify_add_watch() syscall. +func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + mask := args[2].Uint() + + // "EINVAL: The given event mask contains no valid events." + // -- inotify_add_watch(2) + if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 { + return 0, nil, syserror.EINVAL + } + + // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link." + // -- inotify(7) + follow := followFinalSymlink + if mask&linux.IN_DONT_FOLLOW == 0 { + follow = nofollowFinalSymlink + } + + ino, f, err := fdToInotify(t, fd) + if err != nil { + return 0, nil, err + } + defer f.DecRef() + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + if mask&linux.IN_ONLYDIR != 0 { + path.Dir = true + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, follow) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{}) + if err != nil { + return 0, nil, err + } + defer d.DecRef() + + fd = ino.AddWatch(d.Dentry(), mask) + return uintptr(fd), nil, err +} + +// InotifyRmWatch implements the inotify_rm_watch() syscall. +func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + wd := args[1].Int() + + ino, f, err := fdToInotify(t, fd) + if err != nil { + return 0, nil, err + } + defer f.DecRef() + return 0, nil, ino.RmWatch(wd) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go new file mode 100644 index 000000000..adeaa39cc --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -0,0 +1,145 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Mount implements Linux syscall mount(2). +func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + sourceAddr := args[0].Pointer() + targetAddr := args[1].Pointer() + typeAddr := args[2].Pointer() + flags := args[3].Uint64() + dataAddr := args[4].Pointer() + + // For null-terminated strings related to mount(2), Linux copies in at most + // a page worth of data. See fs/namespace.c:copy_mount_string(). + fsType, err := t.CopyInString(typeAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + source, err := t.CopyInString(sourceAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + + targetPath, err := copyInPath(t, targetAddr) + if err != nil { + return 0, nil, err + } + + data := "" + if dataAddr != 0 { + // In Linux, a full page is always copied in regardless of null + // character placement, and the address is passed to each file system. + // Most file systems always treat this data as a string, though, and so + // do all of the ones we implement. + data, err = t.CopyInString(dataAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + } + + // Ignore magic value that was required before Linux 2.4. + if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL { + flags = flags &^ linux.MS_MGC_MSK + } + + // Must have CAP_SYS_ADMIN in the current mount namespace's associated user + // namespace. + creds := t.Credentials() + if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) { + return 0, nil, syserror.EPERM + } + + const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND | + linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE | + linux.MS_UNBINDABLE | linux.MS_MOVE + + // Silently allow MS_NOSUID, since we don't implement set-id bits + // anyway. + const unsupportedFlags = linux.MS_NODEV | + linux.MS_NODIRATIME | linux.MS_STRICTATIME + + // Linux just allows passing any flags to mount(2) - it won't fail when + // unknown or unsupported flags are passed. Since we don't implement + // everything, we fail explicitly on flags that are unimplemented. + if flags&(unsupportedOps|unsupportedFlags) != 0 { + return 0, nil, syserror.EINVAL + } + + var opts vfs.MountOptions + if flags&linux.MS_NOATIME == linux.MS_NOATIME { + opts.Flags.NoATime = true + } + if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { + opts.Flags.NoExec = true + } + if flags&linux.MS_RDONLY == linux.MS_RDONLY { + opts.ReadOnly = true + } + opts.GetFilesystemOptions.Data = data + + target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer target.Release() + + return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) +} + +// Umount2 implements Linux syscall umount2(2). +func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + flags := args[1].Int() + + // Must have CAP_SYS_ADMIN in the mount namespace's associated user + // namespace. + // + // Currently, this is always the init task's user namespace. + creds := t.Credentials() + if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) { + return 0, nil, syserror.EPERM + } + + const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE + if flags&unsupported != 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + opts := vfs.UmountOptions{ + Flags: uint32(flags), + } + + return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 3a7ef24f5..7f9debd4a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -93,11 +93,17 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -128,6 +134,9 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } @@ -248,11 +257,17 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -283,6 +298,9 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } @@ -345,11 +363,17 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } @@ -380,6 +404,9 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return total, err } @@ -500,11 +527,17 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -535,6 +568,9 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go new file mode 100644 index 000000000..945a364a7 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -0,0 +1,291 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Splice implements Linux syscall splice(2). +func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + inOffsetPtr := args[1].Pointer() + outFD := args[2].Int() + outOffsetPtr := args[3].Pointer() + count := int64(args[4].SizeT()) + flags := args[5].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // At least one file description must represent a pipe. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe && !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy in offsets. + inOffset := int64(-1) + if inOffsetPtr != 0 { + if inIsPipe { + return 0, nil, syserror.ESPIPE + } + if inFile.Options().DenyPRead { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + if inOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + outOffset := int64(-1) + if outOffsetPtr != 0 { + if outIsPipe { + return 0, nil, syserror.ESPIPE + } + if outFile.Options().DenyPWrite { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + if outOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Move data. + var ( + n int64 + err error + inCh chan struct{} + outCh chan struct{} + ) + for { + // If both input and output are pipes, delegate to the pipe + // implementation. Otherwise, exactly one end is a pipe, which we + // ensure is consistently ordered after the non-pipe FD's locks by + // passing the pipe FD as usermem.IO to the non-pipe end. + switch { + case inIsPipe && outIsPipe: + n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) + case inIsPipe: + if outOffset != -1 { + n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{}) + outOffset += n + } else { + n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{}) + } + case outIsPipe: + if inOffset != -1 { + n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{}) + inOffset += n + } else { + n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) + } + } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } + } + } + + // Copy updated offsets out. + if inOffsetPtr != 0 { + if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + } + if outOffsetPtr != 0 { + if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Tee implements Linux syscall tee(2). +func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + outFD := args[1].Int() + count := int64(args[2].SizeT()) + flags := args[3].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // Both file descriptions must represent pipes. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe || !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy data. + var ( + inCh chan struct{} + outCh chan struct{} + ) + for { + n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 { + return uintptr(n), nil, nil + } + if err != syserror.ErrWouldBlock || nonBlock { + return 0, nil, err + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the tee operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err := t.Block(inCh); err != nil { + return 0, nil, err + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err := t.Block(outCh); err != nil { + return 0, nil, err + } + } + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index a332d01bd..7b6e7571a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -90,8 +90,8 @@ func Override() { s.Table[138] = syscalls.Supported("fstatfs", Fstatfs) s.Table[161] = syscalls.Supported("chroot", Chroot) s.Table[162] = syscalls.Supported("sync", Sync) - delete(s.Table, 165) // mount - delete(s.Table, 166) // umount2 + s.Table[165] = syscalls.Supported("mount", Mount) + s.Table[166] = syscalls.Supported("umount2", Umount2) delete(s.Table, 187) // readahead s.Table[188] = syscalls.Supported("setxattr", Setxattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) @@ -116,9 +116,9 @@ func Override() { s.Table[232] = syscalls.Supported("epoll_wait", EpollWait) s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl) s.Table[235] = syscalls.Supported("utimes", Utimes) - delete(s.Table, 253) // inotify_init - delete(s.Table, 254) // inotify_add_watch - delete(s.Table, 255) // inotify_rm_watch + s.Table[253] = syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil) + s.Table[254] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[255] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil) s.Table[257] = syscalls.Supported("openat", Openat) s.Table[258] = syscalls.Supported("mkdirat", Mkdirat) s.Table[259] = syscalls.Supported("mknodat", Mknodat) @@ -134,8 +134,8 @@ func Override() { s.Table[269] = syscalls.Supported("faccessat", Faccessat) s.Table[270] = syscalls.Supported("pselect", Pselect) s.Table[271] = syscalls.Supported("ppoll", Ppoll) - delete(s.Table, 275) // splice - delete(s.Table, 276) // tee + s.Table[275] = syscalls.Supported("splice", Splice) + s.Table[276] = syscalls.Supported("tee", Tee) s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange) s.Table[280] = syscalls.Supported("utimensat", Utimensat) s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait) @@ -151,7 +151,7 @@ func Override() { s.Table[291] = syscalls.Supported("epoll_create1", EpollCreate1) s.Table[292] = syscalls.Supported("dup3", Dup3) s.Table[293] = syscalls.Supported("pipe2", Pipe2) - delete(s.Table, 294) // inotify_init1 + s.Table[294] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil) s.Table[295] = syscalls.Supported("preadv", Preadv) s.Table[296] = syscalls.Supported("pwritev", Pwritev) s.Table[299] = syscalls.Supported("recvmmsg", RecvMMsg) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 94d69c1cc..774cc66cc 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "event_list", + out = "event_list.go", + package = "vfs", + prefix = "event", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Event", + "Linker": "*Event", + }, +) + go_library( name = "vfs", srcs = [ @@ -25,11 +37,13 @@ go_library( "device.go", "epoll.go", "epoll_interest_list.go", + "event_list.go", "file_description.go", "file_description_impl_util.go", "filesystem.go", "filesystem_impl_util.go", "filesystem_type.go", + "inotify.go", "mount.go", "mount_unsafe.go", "options.go", @@ -57,6 +71,7 @@ go_library( "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/uniqueid", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 9aa133bcb..66f3105bd 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -39,8 +39,8 @@ Mount references are held by: - Mount: Each referenced Mount holds a reference on its parent, which is the mount containing its mount point. -- VirtualFilesystem: A reference is held on each Mount that has not been - umounted. +- VirtualFilesystem: A reference is held on each Mount that has been connected + to a mount point, but not yet umounted. MountNamespace and FileDescription references are held by users of VFS. The expectation is that each `kernel.Task` holds a reference on its corresponding diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index caf770fd5..b7c6b60b8 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -297,3 +297,15 @@ func (d *anonDentry) TryIncRef() bool { func (d *anonDentry) DecRef() { // no-op } + +// InotifyWithParent implements DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *anonDentry) InotifyWithParent(events uint32, cookie uint32, et EventType) {} + +// Watches implements DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *anonDentry) Watches() *Watches { + return nil +} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 8624dbd5d..24af13eb1 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -103,6 +103,22 @@ type DentryImpl interface { // DecRef decrements the Dentry's reference count. DecRef() + + // InotifyWithParent notifies all watches on the targets represented by this + // dentry and its parent. The parent's watches are notified first, followed + // by this dentry's. + // + // InotifyWithParent automatically adds the IN_ISDIR flag for dentries + // representing directories. + // + // Note that the events may not actually propagate up to the user, depending + // on the event masks. + InotifyWithParent(events uint32, cookie uint32, et EventType) + + // Watches returns the set of inotify watches for the file corresponding to + // the Dentry. Dentries that are hard links to the same underlying file + // share the same watches. + Watches() *Watches } // IncRef increments d's reference count. @@ -133,6 +149,17 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } +// InotifyWithParent notifies all watches on the inodes for this dentry and +// its parent of events. +func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et EventType) { + d.impl.InotifyWithParent(events, cookie, et) +} + +// Watches returns the set of inotify watches associated with d. +func (d *Dentry) Watches() *Watches { + return d.impl.Watches() +} + // The following functions are exported so that filesystem implementations can // use them. The vfs package, and users of VFS, should not call these // functions. diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index cfabd936c..bb294563d 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -210,6 +210,11 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry { return fd.vd } +// Options returns the options passed to fd.Init(). +func (fd *FileDescription) Options() FileDescriptionOptions { + return fd.opts +} + // StatusFlags returns file description status flags, as for fcntl(F_GETFL). func (fd *FileDescription) StatusFlags() uint32 { return atomic.LoadUint32(&fd.statusFlags) diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 286510195..8882fa84a 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -43,7 +43,7 @@ type Dentry struct { // IsAncestorDentry returns true if d is an ancestor of d2; that is, d is // either d2's parent or an ancestor of d2's parent. func IsAncestorDentry(d, d2 *Dentry) bool { - for { + for d2 != nil { // Stop at root, where d2.parent == nil. if d2.parent == d { return true } @@ -52,6 +52,7 @@ func IsAncestorDentry(d, d2 *Dentry) bool { } d2 = d2.parent } + return false } // ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go new file mode 100644 index 000000000..05a3051a4 --- /dev/null +++ b/pkg/sentry/vfs/inotify.go @@ -0,0 +1,697 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "bytes" + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// inotifyEventBaseSize is the base size of linux's struct inotify_event. This +// must be a power 2 for rounding below. +const inotifyEventBaseSize = 16 + +// EventType defines different kinds of inotfiy events. +// +// The way events are labelled appears somewhat arbitrary, but they must match +// Linux so that IN_EXCL_UNLINK behaves as it does in Linux. +type EventType uint8 + +// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and +// FSNOTIFY_EVENT_INODE in Linux. +const ( + PathEvent EventType = iota + InodeEvent EventType = iota +) + +// Inotify represents an inotify instance created by inotify_init(2) or +// inotify_init1(2). Inotify implements FileDescriptionImpl. +// +// Lock ordering: +// Inotify.mu -> Watches.mu -> Inotify.evMu +// +// +stateify savable +type Inotify struct { + vfsfd FileDescription + FileDescriptionDefaultImpl + DentryMetadataFileDescriptionImpl + + // Unique identifier for this inotify instance. We don't just reuse the + // inotify fd because fds can be duped. These should not be exposed to the + // user, since we may aggressively reuse an id on S/R. + id uint64 + + // queue is used to notify interested parties when the inotify instance + // becomes readable or writable. + queue waiter.Queue `state:"nosave"` + + // evMu *only* protects the events list. We need a separate lock while + // queuing events: using mu may violate lock ordering, since at that point + // the calling goroutine may already hold Watches.mu. + evMu sync.Mutex `state:"nosave"` + + // A list of pending events for this inotify instance. Protected by evMu. + events eventList + + // A scratch buffer, used to serialize inotify events. Allocate this + // ahead of time for the sake of performance. Protected by evMu. + scratch []byte + + // mu protects the fields below. + mu sync.Mutex `state:"nosave"` + + // nextWatchMinusOne is used to allocate watch descriptors on this Inotify + // instance. Note that Linux starts numbering watch descriptors from 1. + nextWatchMinusOne int32 + + // Map from watch descriptors to watch objects. + watches map[int32]*Watch +} + +var _ FileDescriptionImpl = (*Inotify)(nil) + +// NewInotifyFD constructs a new Inotify instance. +func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) (*FileDescription, error) { + // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs. + flags &^= linux.O_CLOEXEC + if flags&^linux.O_NONBLOCK != 0 { + return nil, syserror.EINVAL + } + + id := uniqueid.GlobalFromContext(ctx) + vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id)) + defer vd.DecRef() + fd := &Inotify{ + id: id, + scratch: make([]byte, inotifyEventBaseSize), + watches: make(map[int32]*Watch), + } + if err := fd.vfsfd.Init(fd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{ + UseDentryMetadata: true, + DenyPRead: true, + DenyPWrite: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// Release implements FileDescriptionImpl.Release. Release removes all +// watches and frees all resources for an inotify instance. +func (i *Inotify) Release() { + // We need to hold i.mu to avoid a race with concurrent calls to + // Inotify.handleDeletion from Watches. There's no risk of Watches + // accessing this Inotify after the destructor ends, because we remove all + // references to it below. + i.mu.Lock() + defer i.mu.Unlock() + for _, w := range i.watches { + // Remove references to the watch from the watches set on the target. We + // don't need to worry about the references from i.watches, since this + // file description is about to be destroyed. + w.set.Remove(i.id) + } +} + +// EventRegister implements waiter.Waitable. +func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + i.queue.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable. +func (i *Inotify) EventUnregister(e *waiter.Entry) { + i.queue.EventUnregister(e) +} + +// Readiness implements waiter.Waitable.Readiness. +// +// Readiness indicates whether there are pending events for an inotify instance. +func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { + ready := waiter.EventMask(0) + + i.evMu.Lock() + defer i.evMu.Unlock() + + if !i.events.Empty() { + ready |= waiter.EventIn + } + + return mask & ready +} + +// PRead implements FileDescriptionImpl. +func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// PWrite implements FileDescriptionImpl. +func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements FileDescriptionImpl.Write. +func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Read implements FileDescriptionImpl.Read. +func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + if dst.NumBytes() < inotifyEventBaseSize { + return 0, syserror.EINVAL + } + + i.evMu.Lock() + defer i.evMu.Unlock() + + if i.events.Empty() { + // Nothing to read yet, tell caller to block. + return 0, syserror.ErrWouldBlock + } + + var writeLen int64 + for it := i.events.Front(); it != nil; { + // Advance `it` before the element is removed from the list, or else + // it.Next() will always be nil. + event := it + it = it.Next() + + // Does the buffer have enough remaining space to hold the event we're + // about to write out? + if dst.NumBytes() < int64(event.sizeOf()) { + if writeLen > 0 { + // Buffer wasn't big enough for all pending events, but we did + // write some events out. + return writeLen, nil + } + return 0, syserror.EINVAL + } + + // Linux always dequeues an available event as long as there's enough + // buffer space to copy it out, even if the copy below fails. Emulate + // this behaviour. + i.events.Remove(event) + + // Buffer has enough space, copy event to the read buffer. + n, err := event.CopyTo(ctx, i.scratch, dst) + if err != nil { + return 0, err + } + + writeLen += n + dst = dst.DropFirst64(n) + } + return writeLen, nil +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + switch args[1].Int() { + case linux.FIONREAD: + i.evMu.Lock() + defer i.evMu.Unlock() + var n uint32 + for e := i.events.Front(); e != nil; e = e.Next() { + n += uint32(e.sizeOf()) + } + var buf [4]byte + usermem.ByteOrder.PutUint32(buf[:], n) + _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{}) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +func (i *Inotify) queueEvent(ev *Event) { + i.evMu.Lock() + + // Check if we should coalesce the event we're about to queue with the last + // one currently in the queue. Events are coalesced if they are identical. + if last := i.events.Back(); last != nil { + if ev.equals(last) { + // "Coalesce" the two events by simply not queuing the new one. We + // don't need to raise a waiter.EventIn notification because no new + // data is available for reading. + i.evMu.Unlock() + return + } + } + + i.events.PushBack(ev) + + // Release mutex before notifying waiters because we don't control what they + // can do. + i.evMu.Unlock() + + i.queue.Notify(waiter.EventIn) +} + +// newWatchLocked creates and adds a new watch to target. +// +// Precondition: i.mu must be locked. +func (i *Inotify) newWatchLocked(target *Dentry, mask uint32) *Watch { + targetWatches := target.Watches() + w := &Watch{ + owner: i, + wd: i.nextWatchIDLocked(), + set: targetWatches, + mask: mask, + } + + // Hold the watch in this inotify instance as well as the watch set on the + // target. + i.watches[w.wd] = w + targetWatches.Add(w) + return w +} + +// newWatchIDLocked allocates and returns a new watch descriptor. +// +// Precondition: i.mu must be locked. +func (i *Inotify) nextWatchIDLocked() int32 { + i.nextWatchMinusOne++ + return i.nextWatchMinusOne +} + +// handleDeletion handles the deletion of the target of watch w. It removes w +// from i.watches and a watch removal event is generated. +func (i *Inotify) handleDeletion(w *Watch) { + i.mu.Lock() + _, found := i.watches[w.wd] + delete(i.watches, w.wd) + i.mu.Unlock() + + if found { + i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0)) + } +} + +// AddWatch constructs a new inotify watch and adds it to the target. It +// returns the watch descriptor returned by inotify_add_watch(2). +func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { + // Note: Locking this inotify instance protects the result returned by + // Lookup() below. With the lock held, we know for sure the lookup result + // won't become stale because it's impossible for *this* instance to + // add/remove watches on target. + i.mu.Lock() + defer i.mu.Unlock() + + // Does the target already have a watch from this inotify instance? + if existing := target.Watches().Lookup(i.id); existing != nil { + newmask := mask + if mask&linux.IN_MASK_ADD != 0 { + // "Add (OR) events to watch mask for this pathname if it already + // exists (instead of replacing mask)." -- inotify(7) + newmask |= atomic.LoadUint32(&existing.mask) + } + atomic.StoreUint32(&existing.mask, newmask) + return existing.wd + } + + // No existing watch, create a new watch. + w := i.newWatchLocked(target, mask) + return w.wd +} + +// RmWatch looks up an inotify watch for the given 'wd' and configures the +// target to stop sending events to this inotify instance. +func (i *Inotify) RmWatch(wd int32) error { + i.mu.Lock() + + // Find the watch we were asked to removed. + w, ok := i.watches[wd] + if !ok { + i.mu.Unlock() + return syserror.EINVAL + } + + // Remove the watch from this instance. + delete(i.watches, wd) + + // Remove the watch from the watch target. + w.set.Remove(w.OwnerID()) + i.mu.Unlock() + + // Generate the event for the removal. + i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0)) + + return nil +} + +// Watches is the collection of all inotify watches on a single file. +// +// +stateify savable +type Watches struct { + // mu protects the fields below. + mu sync.RWMutex `state:"nosave"` + + // ws is the map of active watches in this collection, keyed by the inotify + // instance id of the owner. + ws map[uint64]*Watch +} + +// Lookup returns the watch owned by an inotify instance with the given id. +// Returns nil if no such watch exists. +// +// Precondition: the inotify instance with the given id must be locked to +// prevent the returned watch from being concurrently modified or replaced in +// Inotify.watches. +func (w *Watches) Lookup(id uint64) *Watch { + w.mu.Lock() + defer w.mu.Unlock() + return w.ws[id] +} + +// Add adds watch into this set of watches. +// +// Precondition: the inotify instance with the given id must be locked. +func (w *Watches) Add(watch *Watch) { + w.mu.Lock() + defer w.mu.Unlock() + + owner := watch.OwnerID() + // Sanity check, we should never have two watches for one owner on the + // same target. + if _, exists := w.ws[owner]; exists { + panic(fmt.Sprintf("Watch collision with ID %+v", owner)) + } + if w.ws == nil { + w.ws = make(map[uint64]*Watch) + } + w.ws[owner] = watch +} + +// Remove removes a watch with the given id from this set of watches and +// releases it. The caller is responsible for generating any watch removal +// event, as appropriate. The provided id must match an existing watch in this +// collection. +// +// Precondition: the inotify instance with the given id must be locked. +func (w *Watches) Remove(id uint64) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.ws == nil { + // This watch set is being destroyed. The thread executing the + // destructor is already in the process of deleting all our watches. We + // got here with no references on the target because we raced with the + // destructor notifying all the watch owners of destruction. See the + // comment in Watches.HandleDeletion for why this race exists. + return + } + + if _, ok := w.ws[id]; !ok { + // While there's technically no problem with silently ignoring a missing + // watch, this is almost certainly a bug. + panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id)) + } + delete(w.ws, id) +} + +// Notify queues a new event with all watches in this set. +func (w *Watches) Notify(name string, events, cookie uint32, et EventType) { + w.NotifyWithExclusions(name, events, cookie, et, false) +} + +// NotifyWithExclusions queues a new event with watches in this set. Watches +// with IN_EXCL_UNLINK are skipped if the event is coming from a child that +// has been unlinked. +func (w *Watches) NotifyWithExclusions(name string, events, cookie uint32, et EventType, unlinked bool) { + // N.B. We don't defer the unlocks because Notify is in the hot path of + // all IO operations, and the defer costs too much for small IO + // operations. + w.mu.RLock() + for _, watch := range w.ws { + if unlinked && watch.ExcludeUnlinkedChildren() && et == PathEvent { + continue + } + watch.Notify(name, events, cookie) + } + w.mu.RUnlock() +} + +// HandleDeletion is called when the watch target is destroyed to emit +// the appropriate events. +func (w *Watches) HandleDeletion() { + w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent) + + // TODO(gvisor.dev/issue/1479): This doesn't work because maps are not copied + // by value. Ideally, we wouldn't have this circular locking so we can just + // notify of IN_DELETE_SELF in the same loop below. + // + // We can't hold w.mu while calling watch.handleDeletion to preserve lock + // ordering w.r.t to the owner inotify instances. Instead, atomically move + // the watches map into a local variable so we can iterate over it safely. + // + // Because of this however, it is possible for the watches' owners to reach + // this inode while the inode has no refs. This is still safe because the + // owners can only reach the inode until this function finishes calling + // watch.handleDeletion below and the inode is guaranteed to exist in the + // meantime. But we still have to be very careful not to rely on inode state + // that may have been already destroyed. + var ws map[uint64]*Watch + w.mu.Lock() + ws = w.ws + w.ws = nil + w.mu.Unlock() + + for _, watch := range ws { + // TODO(gvisor.dev/issue/1479): consider refactoring this. + watch.handleDeletion() + } +} + +// Watch represent a particular inotify watch created by inotify_add_watch. +// +// +stateify savable +type Watch struct { + // Inotify instance which owns this watch. + owner *Inotify + + // Descriptor for this watch. This is unique across an inotify instance. + wd int32 + + // set is the watch set containing this watch. It belongs to the target file + // of this watch. + set *Watches + + // Events being monitored via this watch. Must be accessed with atomic + // memory operations. + mask uint32 +} + +// OwnerID returns the id of the inotify instance that owns this watch. +func (w *Watch) OwnerID() uint64 { + return w.owner.id +} + +// ExcludeUnlinkedChildren indicates whether the watched object should continue +// to be notified of events of its children after they have been unlinked, e.g. +// for an open file descriptor. +// +// TODO(gvisor.dev/issue/1479): Implement IN_EXCL_UNLINK. +// We can do this by keeping track of the set of unlinked children in Watches +// to skip notification. +func (w *Watch) ExcludeUnlinkedChildren() bool { + return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0 +} + +// Notify queues a new event on this watch. +func (w *Watch) Notify(name string, events uint32, cookie uint32) { + mask := atomic.LoadUint32(&w.mask) + if mask&events == 0 { + // We weren't watching for this event. + return + } + + // Event mask should include bits matched from the watch plus all control + // event bits. + unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS + effectiveMask := unmaskableBits | mask + matchedEvents := effectiveMask & events + w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie)) +} + +// handleDeletion handles the deletion of w's target. +func (w *Watch) handleDeletion() { + w.owner.handleDeletion(w) +} + +// Event represents a struct inotify_event from linux. +// +// +stateify savable +type Event struct { + eventEntry + + wd int32 + mask uint32 + cookie uint32 + + // len is computed based on the name field is set automatically by + // Event.setName. It should be 0 when no name is set; otherwise it is the + // length of the name slice. + len uint32 + + // The name field has special padding requirements and should only be set by + // calling Event.setName. + name []byte +} + +func newEvent(wd int32, name string, events, cookie uint32) *Event { + e := &Event{ + wd: wd, + mask: events, + cookie: cookie, + } + if name != "" { + e.setName(name) + } + return e +} + +// paddedBytes converts a go string to a null-terminated c-string, padded with +// null bytes to a total size of 'l'. 'l' must be large enough for all the bytes +// in the 's' plus at least one null byte. +func paddedBytes(s string, l uint32) []byte { + if l < uint32(len(s)+1) { + panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!") + } + b := make([]byte, l) + copy(b, s) + + // b was zero-value initialized during make(), so the rest of the slice is + // already filled with null bytes. + + return b +} + +// setName sets the optional name for this event. +func (e *Event) setName(name string) { + // We need to pad the name such that the entire event length ends up a + // multiple of inotifyEventBaseSize. + unpaddedLen := len(name) + 1 + // Round up to nearest multiple of inotifyEventBaseSize. + e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1)) + // Make sure we haven't overflowed and wrapped around when rounding. + if unpaddedLen > int(e.len) { + panic("Overflow when rounding inotify event size, the 'name' field was too big.") + } + e.name = paddedBytes(name, e.len) +} + +func (e *Event) sizeOf() int { + s := inotifyEventBaseSize + int(e.len) + if s < inotifyEventBaseSize { + panic("overflow") + } + return s +} + +// CopyTo serializes this event to dst. buf is used as a scratch buffer to +// construct the output. We use a buffer allocated ahead of time for +// performance. buf must be at least inotifyEventBaseSize bytes. +func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) { + usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) + usermem.ByteOrder.PutUint32(buf[4:], e.mask) + usermem.ByteOrder.PutUint32(buf[8:], e.cookie) + usermem.ByteOrder.PutUint32(buf[12:], e.len) + + writeLen := 0 + + n, err := dst.CopyOut(ctx, buf) + if err != nil { + return 0, err + } + writeLen += n + dst = dst.DropFirst(n) + + if e.len > 0 { + n, err = dst.CopyOut(ctx, e.name) + if err != nil { + return 0, err + } + writeLen += n + } + + // Santiy check. + if writeLen != e.sizeOf() { + panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %d, wrote %d.", e.sizeOf(), writeLen)) + } + + return int64(writeLen), nil +} + +func (e *Event) equals(other *Event) bool { + return e.wd == other.wd && + e.mask == other.mask && + e.cookie == other.cookie && + e.len == other.len && + bytes.Equal(e.name, other.name) +} + +// InotifyEventFromStatMask generates the appropriate events for an operation +// that set the stats specified in mask. +func InotifyEventFromStatMask(mask uint32) uint32 { + var ev uint32 + if mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_MODE) != 0 { + ev |= linux.IN_ATTRIB + } + if mask&linux.STATX_SIZE != 0 { + ev |= linux.IN_MODIFY + } + + if (mask & (linux.STATX_ATIME | linux.STATX_MTIME)) == (linux.STATX_ATIME | linux.STATX_MTIME) { + // Both times indicates a utime(s) call. + ev |= linux.IN_ATTRIB + } else if mask&linux.STATX_ATIME != 0 { + ev |= linux.IN_ACCESS + } else if mask&linux.STATX_MTIME != 0 { + mask |= linux.IN_MODIFY + } + return ev +} + +// InotifyRemoveChild sends the appriopriate notifications to the watch sets of +// the child being removed and its parent. +func InotifyRemoveChild(self, parent *Watches, name string) { + self.Notify("", linux.IN_ATTRIB, 0, InodeEvent) + parent.Notify(name, linux.IN_DELETE, 0, InodeEvent) + // TODO(gvisor.dev/issue/1479): implement IN_EXCL_UNLINK. +} + +// InotifyRename sends the appriopriate notifications to the watch sets of the +// file being renamed and its old/new parents. +func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, oldName, newName string, isDir bool) { + var dirEv uint32 + if isDir { + dirEv = linux.IN_ISDIR + } + cookie := uniqueid.InotifyCookie(ctx) + oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent) + newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent) + // Somewhat surprisingly, self move events do not have a cookie. + renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent) +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 02850b65c..32f901bd8 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -28,9 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// lastMountID is used to allocate mount ids. Must be accessed atomically. -var lastMountID uint64 - // A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem // (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem // (Mount.fs), which applies to path resolution in the context of a particular @@ -58,6 +55,10 @@ type Mount struct { // ID is the immutable mount ID. ID uint64 + // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except + // for MS_RDONLY which is tracked in "writers". Immutable. + Flags MountFlags + // key is protected by VirtualFilesystem.mountMu and // VirtualFilesystem.mounts.seq, and may be nil. References are held on // key.parent and key.point if they are not nil. @@ -84,10 +85,6 @@ type Mount struct { // umounted is true. umounted is protected by VirtualFilesystem.mountMu. umounted bool - // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except - // for MS_RDONLY which is tracked in "writers". - flags MountFlags - // The lower 63 bits of writers is the number of calls to // Mount.CheckBeginWrite() that have not yet been paired with a call to // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. @@ -97,11 +94,11 @@ type Mount struct { func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount { mnt := &Mount{ - ID: atomic.AddUint64(&lastMountID, 1), + ID: atomic.AddUint64(&vfs.lastMountID, 1), + Flags: opts.Flags, vfs: vfs, fs: fs, root: root, - flags: opts.Flags, ns: mntns, refs: 1, } @@ -111,8 +108,17 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount return mnt } -// A MountNamespace is a collection of Mounts. -// +// Options returns a copy of the MountOptions currently applicable to mnt. +func (mnt *Mount) Options() MountOptions { + mnt.vfs.mountMu.Lock() + defer mnt.vfs.mountMu.Unlock() + return MountOptions{ + Flags: mnt.Flags, + ReadOnly: mnt.readOnly(), + } +} + +// A MountNamespace is a collection of Mounts.// // MountNamespaces are reference-counted. Unless otherwise specified, all // MountNamespace methods require that a reference is held. // @@ -120,6 +126,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount // // +stateify savable type MountNamespace struct { + // Owner is the usernamespace that owns this mount namespace. + Owner *auth.UserNamespace + // root is the MountNamespace's root mount. root is immutable. root *Mount @@ -148,7 +157,7 @@ type MountNamespace struct { func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { - ctx.Warningf("Unknown filesystem: %s", fsTypeName) + ctx.Warningf("Unknown filesystem type: %s", fsTypeName) return nil, syserror.ENODEV } fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) @@ -156,6 +165,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return nil, err } mntns := &MountNamespace{ + Owner: creds.UserNamespace, refs: 1, mountpoints: make(map[*Dentry]uint32), } @@ -175,26 +185,34 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, return newMount(vfs, fs, root, nil /* mntns */, opts), nil } -// MountAt creates and mounts a Filesystem configured by the given arguments. -func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { +// MountDisconnected creates a Filesystem configured by the given arguments, +// then returns a Mount representing it. The new Mount is not associated with +// any MountNamespace and is not connected to any other Mounts. +func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth.Credentials, source string, fsTypeName string, opts *MountOptions) (*Mount, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { - return syserror.ENODEV + return nil, syserror.ENODEV } if !opts.InternalMount && !rft.opts.AllowUserMount { - return syserror.ENODEV + return nil, syserror.ENODEV } fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { - return err + return nil, err } + defer root.DecRef() + defer fs.DecRef() + return vfs.NewDisconnectedMount(fs, root, opts) +} +// ConnectMountAt connects mnt at the path represented by target. +// +// Preconditions: mnt must be disconnected. +func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Credentials, mnt *Mount, target *PathOperation) error { // We can't hold vfs.mountMu while calling FilesystemImpl methods due to // lock ordering. vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{}) if err != nil { - root.DecRef() - fs.DecRef() return err } vfs.mountMu.Lock() @@ -204,8 +222,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia vd.dentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef() - root.DecRef() - fs.DecRef() return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -238,7 +254,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia // point and the mount root are directories, or neither are, and returns // ENOTDIR if this is not the case. mntns := vd.mount.ns - mnt := newMount(vfs, fs, root, mntns, opts) vfs.mounts.seq.BeginWrite() vfs.connectLocked(mnt, vd, mntns) vfs.mounts.seq.EndWrite() @@ -247,6 +262,19 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia return nil } +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { + mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts) + if err != nil { + return err + } + defer mnt.DecRef() + if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil { + return err + } + return nil +} + // UmountAt removes the Mount at the given path. func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { @@ -254,6 +282,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti } // MNT_FORCE is currently unimplemented except for the permission check. + // Force unmounting specifically requires CAP_SYS_ADMIN in the root user + // namespace, and not in the owner user namespace for the target mount. See + // fs/namespace.c:SYSCALL_DEFINE2(umount, ...) if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { return syserror.EPERM } @@ -369,14 +400,22 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu // references held by vd. // // Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. d.mu must be locked. mnt.parent() == nil. +// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt +// must not already be connected. func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { + if checkInvariants { + if mnt.parent() != nil { + panic("VFS.connectLocked called on connected mount") + } + } + mnt.IncRef() // dropped by callers of umountRecursiveLocked mnt.storeKey(vd) if vd.mount.children == nil { vd.mount.children = make(map[*Mount]struct{}) } vd.mount.children[mnt] = struct{}{} atomic.AddUint32(&vd.dentry.mounts, 1) + mnt.ns = mntns mntns.mountpoints[vd.dentry]++ vfs.mounts.insertSeqed(mnt) vfsmpmounts, ok := vfs.mountpoints[vd.dentry] @@ -394,6 +433,11 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns // writer critical section. mnt.parent() != nil. func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { vd := mnt.loadKey() + if checkInvariants { + if vd.mount != nil { + panic("VFS.disconnectLocked called on disconnected mount") + } + } mnt.storeKey(VirtualDentry{}) delete(vd.mount.children, mnt) atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 @@ -715,7 +759,10 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if mnt.readOnly() { opts = "ro" } - if mnt.flags.NoExec { + if mnt.Flags.NoATime { + opts = ",noatime" + } + if mnt.Flags.NoExec { opts += ",noexec" } @@ -800,11 +847,12 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo if mnt.readOnly() { opts = "ro" } - if mnt.flags.NoExec { + if mnt.Flags.NoATime { + opts = ",noatime" + } + if mnt.Flags.NoExec { opts += ",noexec" } - // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is - // set. fmt.Fprintf(buf, "%s ", opts) // (7) Optional fields: zero or more fields of the form "tag[:value]". diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index bc7581698..70f850ca4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 53d364c5c..f223aeda8 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -75,6 +75,10 @@ type MknodOptions struct { type MountFlags struct { // NoExec is equivalent to MS_NOEXEC. NoExec bool + + // NoATime is equivalent to MS_NOATIME and indicates that the + // filesystem should not update access time in-place. + NoATime bool } // MountOptions contains options to VirtualFilesystem.MountAt(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 8d7f8f8af..9acca8bc7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -82,6 +82,10 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // lastMountID is the last allocated mount ID. lastMountID is accessed + // using atomic memory operations. + lastMountID uint64 + // anonMount is a Mount, not included in mounts or mountpoints, // representing an anonFilesystem. anonMount is used to back // VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry(). @@ -401,7 +405,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential vfs.putResolvingPath(rp) if opts.FileExec { - if fd.Mount().flags.NoExec { + if fd.Mount().Flags.NoExec { fd.DecRef() return nil, syserror.EACCES } @@ -418,6 +422,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential } } + fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent) return fd, nil } if !rp.handleError(err) { diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 101497ed6..e2894f9f5 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -77,7 +77,10 @@ var DefaultOpts = Opts{ // trigger it. const descheduleThreshold = 1 * time.Second -var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected") +var ( + stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout") + stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected") +) // Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck. var stackDumpSameTaskPeriod = time.Minute @@ -220,6 +223,9 @@ func (w *Watchdog) waitForStart() { // We are fine. return } + + stuckStartup.Increment() + var buf bytes.Buffer buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) w.doAction(w.StartupTimeoutAction, false, &buf) @@ -328,8 +334,8 @@ func (w *Watchdog) reportStuckWatchdog() { } // doAction will take the given action. If the action is LogWarning, the stack -// is not always dumpped to the log to prevent log flooding. "forceStack" -// guarantees that the stack will be dumped regarless. +// is not always dumped to the log to prevent log flooding. "forceStack" +// guarantees that the stack will be dumped regardless. func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) { switch action { case LogWarning: diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 65bfcf778..f68c12620 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go index ad4a3a37e..1d7780695 100644 --- a/pkg/sync/memmove_unsafe.go +++ b/pkg/sync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index 3dd15578b..dc034d561 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.15 +// +build !go1.16 // When updating the build constraint (above), check that syncMutex matches the // standard library sync.Mutex definition. diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index ea6cdc447..995c0346e 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go index 112e0e604..ad271e1a0 100644 --- a/pkg/syncevent/waiter_unsafe.go +++ b/pkg/syncevent/waiter_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index e57d45f2a..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -22,7 +22,6 @@ go_test( size = "small", srcs = ["gonet_test.go"], library = ":gonet", - tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 6e0db2741..d82ed5205 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -335,6 +335,11 @@ func (c *TCPConn) Read(b []byte) (int, error) { deadline := c.readCancel() numRead := 0 + defer func() { + if numRead != 0 { + c.ep.ModerateRecvBuf(numRead) + } + }() for numRead != len(b) { if len(c.read) == 0 { var err error diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 9bf67686d..20b183da0 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -181,13 +181,13 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() p := PacketInfo{ - Pkt: &pkt, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index affa1bbdf..f34082e1a 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -387,7 +387,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -641,8 +641,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 3bfb15a8e..eaee7e5d7 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents stack.PacketBuffer + contents *stack.PacketBuffer } type context struct { @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), Hash: hash, @@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: stack.PacketBuffer{ + contents: &stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index fe2bf3b0b..2dfd29aa9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index cb4cbea69..f04738cfb 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,13 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -169,7 +169,7 @@ type recvMMsgDispatcher struct { // iovecs is an array of array of iovec records where each iovec base // pointer and length are initialzed to the corresponding view above, - // except when GSO is neabled then the first iovec in each array of + // except when GSO is enabled then the first iovec in each array of // iovecs points to a buffer for the vnet header which is stripped // before the views are passed up the stack for further processing. iovecs [][]syscall.Iovec @@ -296,12 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..568c6874f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a5478ce17..c69d6b7e9 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,8 +80,8 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 87c734c1f..0744f66d6 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 54432194d..b5dfb7850 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,8 +102,8 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements stack.LinkEndpoint.Attach. @@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() @@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] - if !d.q.enqueue(&pkt) { + if !d.q.enqueue(pkt) { return tcpip.ErrNoBufferSpace } d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 0b5a6cf49..99313ee25 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0796d717e..0374a2441 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..28a2e88ba 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index da1c520ae..ae3186314 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -120,9 +120,9 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, &pkt) - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -208,8 +208,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, &pkt) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 617446ea2..6bc9033d0 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 2b3741276..949b3f2b2 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 54eb5322b..63bf40562 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..ea1acba83 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -94,11 +94,11 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { v, ok := pkt.Data.PullUp(header.ARPSize) if !ok { return @@ -122,7 +122,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -177,7 +177,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..66e67429c 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..d9b62f2db 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,7 +293,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -379,7 +383,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: vv, }) if want := c.expectedCount; o.controlCalls != want { @@ -444,7 +448,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag1.ToVectorisedView(), }) if o.dataCalls != 0 { @@ -452,7 +456,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag2.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -487,7 +491,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,7 +538,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -644,7 +652,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view[:len(view)-c.trunc].ToVectorisedView(), }) if want := c.expectedCount; o.controlCalls != want { diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..d1c3ae835 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -56,7 +56,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) @@ -88,7 +88,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -102,7 +102,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..959f7e007 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -129,7 +129,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -169,7 +169,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -188,7 +188,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -202,7 +202,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -245,7 +245,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,43 +253,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. netHeader := header.IPv4(pkt.NetworkHeader) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + handleLoopback(&route, pkt, ep) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + handleLoopback(&loopedR, pkt, e) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -305,6 +291,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } +func handleLoopback(route *stack.Route, pkt *stack.PacketBuffer, ep stack.NetworkEndpoint) { + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + ep.HandlePacket(route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) +} + // WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { @@ -347,18 +344,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + handleLoopback(&route, pkt, ep) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } @@ -370,7 +361,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) @@ -426,7 +417,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -447,7 +438,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..c208ebd99 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan stack.PacketBuffer + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -184,7 +184,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ + source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []stack.PacketBuffer + var results []*stack.PacketBuffer L: for { select { @@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } @@ -698,7 +706,7 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b62fb1de6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -70,7 +70,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -288,7 +288,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -390,7 +390,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -532,7 +532,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..a720f626f 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +67,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -189,7 +189,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -328,7 +328,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ Data: vv, }) } @@ -563,7 +563,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -740,7 +740,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -918,7 +918,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..0d94ad122 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -163,14 +163,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..213ff64f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..3c141b91b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -568,7 +568,7 @@ func TestNDPValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, stack.PacketBuffer{ + ep.HandlePacket(r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -884,7 +884,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..d4053be08 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -186,7 +186,7 @@ func parseHeaders(pkt *PacketBuffer) { } // packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { +func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple netHeader := header.IPv4(pkt.NetworkHeader) @@ -265,7 +265,7 @@ func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, creat } var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) + tuple, err := packetToTuple(pkt, hook) if err != nil { return nil, dir } diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 8084d50bc..63537aaad 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Consume the network header. b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) if !ok { @@ -96,7 +96,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -190,7 +190,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,13 +203,13 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -251,7 +251,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +270,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +280,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -362,7 +362,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -399,7 +399,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -430,7 +430,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -460,7 +460,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -468,7 +468,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -510,7 +510,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -557,7 +557,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[0] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -610,7 +610,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 443423b3c..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -44,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -107,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -159,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -185,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -314,7 +343,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -322,7 +351,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } @@ -335,47 +364,3 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(filter.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which begins - // with the name should be matched. - ifName := filter.OutputInterface - matches = true - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName - } - return filter.OutputInterfaceInvert != matches - } - - return true -} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fe06007ae..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,7 +15,11 @@ package stack import ( + "strings" + "sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // A Hook specifies one of the hooks built into the network stack. @@ -75,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex + + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table - // Priorities maps each hook to a list of table names. The order of the + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } @@ -159,6 +167,16 @@ type IPHeaderFilter struct { // comparison. DstInvert bool + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + // OutputInterface matches the name of the outgoing interface for the // packet. OutputInterface string @@ -173,6 +191,55 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + // A Matcher is the interface for matching packets. type Matcher interface { // Name returns the name of the Matcher. @@ -183,7 +250,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..ae7a8f740 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -750,7 +750,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1881,7 +1881,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..58f1ebf60 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 54103fdb3..6664aea06 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1153,7 +1153,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1167,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1229,18 +1229,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } } if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) + handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) return } @@ -1298,24 +1299,27 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { pkt.Header = buffer.NewPrependable(linkHeaderLen) } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1362,7 +1366,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index d672fc157..fea46158c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -44,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..1b5da6017 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -24,6 +24,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ noCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -82,7 +84,32 @@ type PacketBuffer struct { // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b331427c6..94f177841 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -240,17 +240,18 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have - // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + // protocol. It takes ownership of pkt. pkt.TransportHeader must have already + // been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. + // protocol. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + // header to the given destination address. It takes ownership of pkt. + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +266,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -326,7 +327,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -382,17 +383,17 @@ type LinkEndpoint interface { LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the - // given route. It sets pkt.LinkHeader if a link layer header exists. - // pkt.NetworkHeader and pkt.TransportHeader must have already been - // set. + // given route. It takes ownership of pkt. pkt.NetworkHeader and + // pkt.TransportHeader must have already been set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. pkts must not be zero length. + // given route. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may @@ -400,7 +401,7 @@ type LinkEndpoint interface { WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. + // should already have an ethernet header. It takes ownership of vv. WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer @@ -430,7 +431,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 150297ab9..f5b6ca0b9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,17 +153,20 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } @@ -175,9 +178,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } + // WritePackets takes ownership of pkt, calculate length first. + numPkts := pkts.Len() + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) @@ -193,17 +199,20 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..294ce8775 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -778,7 +775,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..f6ddc3ced 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -132,7 +132,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -147,7 +147,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ + f.HandlePacket(r, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -163,7 +163,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -293,7 +293,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -305,7 +305,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -317,7 +317,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +328,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +340,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +362,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +420,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -2263,7 +2263,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2345,7 +2345,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..e09866405 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -251,7 +251,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -379,7 +379,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -470,7 +470,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +520,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +544,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..67d778137 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -127,7 +127,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +165,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..cb350ead3 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -88,7 +88,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -215,7 +215,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +232,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +289,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -369,7 +369,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +380,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +391,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +446,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +457,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +468,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +623,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 2f98a996f..7f172f978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..57e0a069b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -450,7 +445,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -481,7 +476,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -511,6 +506,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC localPort := uint16(0) switch e.state { + case stateInitial: case stateBound, stateConnected: localPort = e.ID.LocalPort if e.BindNICID == 0 { @@ -743,7 +739,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -805,7 +801,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3c47692b2..2ec6749c7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 23158173d..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() @@ -298,7 +293,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee754a5a..21c34fac2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { @@ -348,7 +343,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -357,7 +352,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), Owner: e.owner, @@ -584,7 +579,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f38eb6833..e26f01fae 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,10 +86,6 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - # FIXME(b/68809571) - tags = [ - "flaky", - ], deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a7e088d4e..7da93dcc4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), Data: data, Hash: tf.txHash, Owner: owner, } - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { tf.ttl = r.DefaultTTL() @@ -1347,6 +1347,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.setEndpointState(StateError) e.HardError = err + e.workerCleanup = true // Lock released below. epilogue() return err diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 6062ca916..047704c80 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -186,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5ba972f1..19f7bf449 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { @@ -1172,11 +1207,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -2462,7 +2492,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2481,7 +2511,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index fc43c11e2..cbb779666 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.EndpointState() { - case StateInitial, StateBound: - // TODO(b/138137272): this enumeration duplicates - // EndpointState.connected. remove it. - case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) @@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() { break } fallthrough - case StateListen, StateConnecting: + case epState == StateListen || epState == StateConnecting: e.drainSegmentLocked() - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } - break } - case StateError, StateClose: + case epState.closed(): for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) @@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state EndpointState) { +func (e *endpoint) loadState(epState EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. // For restore purposes we treat TimeWait like a connected endpoint. - if state.connected() || state == StateTimeWait { + if epState.connected() || epState == StateTimeWait { connectedLoading.Add(1) } - switch state { - case StateListen: + switch { + case epState == StateListen: listenLoading.Add(1) - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState // as the endpoint is still being loaded and the stack reference is not // yet initialized. - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) } // afterLoad is invoked by stateify. @@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - state := e.origEndpointState - switch state { + epState := e.origEndpointState + switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + switch { + case epState.connected(): bind() if len(e.connectingAddress) == 0 { e.connectingAddress = e.ID.RemoteAddress @@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) { closed := e.closed e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) - if state == StateFinWait2 && closed { + if epState == StateFinWait2 && closed { // If the endpoint has been closed then make sure we notify so // that the FIN_WAIT2 timer is started after a restore. e.notifyProtocolGoroutine(notifyClose) } connectedLoading.Done() - case StateListen: + case epState == StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case StateBound: + case epState == StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) { bind() tcpip.AsyncLoading.Done() }() - case StateClose: + case epState == StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) { e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) - case StateError: + case epState == StateError: e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } - } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 704d01c64..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a2a7ddeb..c827d0277 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -206,7 +206,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -217,7 +217,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 074edded6..0c099e2fd 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -60,7 +60,7 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 06dc9b7d7..acacb42e4 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + + // The segment being split does not carry PUSH flag because it is + // followed by the newly split segment. + // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered + // segment (i.e., when there is no more queued data to be sent). + // Linux removes PSH flag only when the segment is being split over MSS + // and retains it when we are splitting the segment over lack of sender + // window space. + // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() + if seg.data.Size() > s.maxPayloadSize { + seg.flags ^= header.TCPFlagPsh + } + seg.data.CapLength(size) } @@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(seg.sequenceNumber.Size(end)) + available := int(s.sndNxt.Size(end)) if available > limit { available = limit } @@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // sent all at once. return false } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { return false } } @@ -816,6 +833,25 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + // If the entire segment cannot be accomodated in the receiver + // advertized window, skip splitting and sending of the segment. + // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered + // by a probe timer. On this condition, it defers the segment + // split and transmit to a short probe timer. + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split the + // segment right here if there are no pending segments. + // If there are pending segments, segment transmits are deferred + // to the retransmit timer handler. + if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { + return false + } + if !seg.sequenceNumber.LessThan(end) { return false } @@ -824,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if available == 0 { return false } + + // The segment size limit is computed as a function of sender congestion + // window and MSS. When sender congestion window is > 1, this limit can + // be larger than MSS. Ensure that the currently available send space + // is not greater than minimum of this limit and MSS. if available > limit { available = limit } + if available > s.maxPayloadSize { + available = s.maxPayloadSize + } if seg.data.Size() > available { s.splitSeg(seg, available) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7b1d72cf4..9721f6caf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: s, }) } @@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 647b2067a..663af8fec 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -247,11 +247,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -921,7 +916,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1269,7 +1268,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { @@ -1336,7 +1335,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() defer e.mu.RUnlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a674ceb68..7abfa0ed2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt stack.PacketBuffer + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..e320c5758 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { // Get the header then trim it from the view. h, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok { @@ -140,7 +140,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -177,7 +177,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8acaa607a..e8ade882b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -440,7 +440,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -487,7 +487,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD deleted file mode 100644 index 2dcba84ae..000000000 --- a/pkg/tmutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmutex", - srcs = ["tmutex.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "tmutex_test", - size = "medium", - srcs = ["tmutex_test.go"], - library = ":tmutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go deleted file mode 100644 index c4685020d..000000000 --- a/pkg/tmutex/tmutex.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tmutex provides the implementation of a mutex that implements an -// efficient TryLock function in addition to Lock and Unlock. -package tmutex - -import ( - "sync/atomic" -) - -// Mutex is a mutual exclusion primitive that implements TryLock in addition -// to Lock and Unlock. -type Mutex struct { - v int32 - ch chan struct{} -} - -// Init initializes the mutex. -func (m *Mutex) Init() { - m.v = 1 - m.ch = make(chan struct{}, 1) -} - -// Lock acquires the mutex. If it is currently held by another goroutine, Lock -// will wait until it has a chance to acquire it. -func (m *Mutex) Lock() { - // Uncontended case. - if atomic.AddInt32(&m.v, -1) == 0 { - return - } - - for { - // Try to acquire the mutex again, at the same time making sure - // that m.v is negative, which indicates to the owner of the - // lock that it is contended, which will force it to try to wake - // someone up when it releases the mutex. - if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 { - return - } - - // Wait for the mutex to be released before trying again. - <-m.ch - } -} - -// TryLock attempts to acquire the mutex without blocking. If the mutex is -// currently held by another goroutine, it fails to acquire it and returns -// false. -func (m *Mutex) TryLock() bool { - v := atomic.LoadInt32(&m.v) - if v <= 0 { - return false - } - return atomic.CompareAndSwapInt32(&m.v, 1, 0) -} - -// Unlock releases the mutex. -func (m *Mutex) Unlock() { - if atomic.SwapInt32(&m.v, 1) == 0 { - // There were no pending waiters. - return - } - - // Wake some waiter up. - select { - case m.ch <- struct{}{}: - default: - } -} diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go deleted file mode 100644 index 05540696a..000000000 --- a/pkg/tmutex/tmutex_test.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmutex - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestBasicLock(t *testing.T) { - var m Mutex - m.Init() - - m.Lock() - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } - - // Make sure we can lock and unlock again. - m.Lock() - m.Unlock() -} - -func TestTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Try to lock. It should succeed. - if !m.TryLock() { - t.Fatalf("TryLock failed on unlocked mutex") - } - - // Try to lock again, it should now fail. - if m.TryLock() { - t.Fatalf("TryLock succeeded on locked mutex") - } - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } -} - -func TestMutualExclusion(t *testing.T) { - var m Mutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestMutualExclusionWithTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Similar to the previous, with the addition of some goroutines that - // only increment the count if TryLock succeeds. - const gr = 1000 - const iters = 100000 - total := int64(gr * iters) - var tryTotal int64 - v := int64(0) - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(2) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - go func() { - local := int64(0) - for j := 0; j < iters; j++ { - if m.TryLock() { - v++ - m.Unlock() - local++ - } - } - atomic.AddInt64(&tryTotal, local) - wg.Done() - }() - } - - wg.Wait() - - t.Logf("tryTotal = %d", tryTotal) - total += tryTotal - - if v != total { - t.Fatalf("Bad count: got %v, want %v", v, total) - } -} - -// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following -// differences: -// -// - The number of goroutines is variable, with the maximum value depending on -// GOMAXPROCS. -// -// - The number of iterations per benchmark is controlled by the benchmarking -// framework. -// -// - Care is taken to ensure that all goroutines participating in the benchmark -// have been created before the benchmark begins. -func BenchmarkTmutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m Mutex - m.Init() - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} - -// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as -// a comparison point. -func BenchmarkSyncMutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m sync.Mutex - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} |