diff options
49 files changed, 1651 insertions, 350 deletions
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 36832ec86..4b4f9bd52 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -225,9 +225,9 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree // Verify will modify the cursor for data, but always restores it to its // original position upon exit. The cursor for tree is modified and not // restored. -func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) error { +func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) { if readSize <= 0 { - return fmt.Errorf("Unexpected read size: %d", readSize) + return 0, fmt.Errorf("Unexpected read size: %d", readSize) } layout := InitLayout(int64(dataSize), dataAndTreeInSameFile) @@ -240,29 +240,30 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in // finishes. origOffset, err := data.Seek(0, io.SeekCurrent) if err != nil { - return fmt.Errorf("Find current data offset failed: %v", err) + return 0, fmt.Errorf("Find current data offset failed: %v", err) } defer data.Seek(origOffset, io.SeekStart) // Move to the first block that contains target data. if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { - return fmt.Errorf("Seek to datablock start failed: %v", err) + return 0, fmt.Errorf("Seek to datablock start failed: %v", err) } buf := make([]byte, layout.blockSize) var readErr error - bytesRead := 0 + total := int64(0) for i := firstDataBlock; i <= lastDataBlock; i++ { // Read a block that includes all or part of target range in // input data. - bytesRead, readErr = data.Read(buf) + bytesRead, err := data.Read(buf) + readErr = err // If at the end of input data and all previous blocks are // verified, return the verified input data and EOF. if readErr == io.EOF && bytesRead == 0 { break } if readErr != nil && readErr != io.EOF { - return fmt.Errorf("Read from data failed: %v", err) + return 0, fmt.Errorf("Read from data failed: %v", err) } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. @@ -274,7 +275,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in } } if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil { - return err + return 0, err } // startOff is the beginning of the read range within the // current data block. Note that for all blocks other than the @@ -298,10 +299,14 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in if endOff > int64(bytesRead) { endOff = int64(bytesRead) } - w.Write(buf[startOff:endOff]) + n, err := w.Write(buf[startOff:endOff]) + if err != nil { + return total, err + } + total += int64(n) } - return readErr + return total, readErr } // verifyBlock verifies a block against tree. index is the number of block in diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index ad50ba5f6..daaca759a 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -67,17 +67,17 @@ func TestLayout(t *testing.T) { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) if l.blockSize != int64(usermem.PageSize) { - t.Errorf("got blockSize %d, want %d", l.blockSize, usermem.PageSize) + t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } if l.digestSize != sha256DigestSize { - t.Errorf("got digestSize %d, want %d", l.digestSize, sha256DigestSize) + t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { - t.Errorf("got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) + t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) } for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ { if l.levelOffset[i] != tc.expectedLevelOffset[i] { - t.Errorf("got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) + t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) } } }) @@ -169,11 +169,11 @@ func TestGenerate(t *testing.T) { }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) } if err != nil { - t.Fatalf("got err: %v, want nil", err) + t.Fatalf("Got err: %v, want nil", err) } if !bytes.Equal(root, tc.expectedRoot) { - t.Errorf("got root: %v, want %v", root, tc.expectedRoot) + t.Errorf("Got root: %v, want %v", root, tc.expectedRoot) } } }) @@ -334,14 +334,21 @@ func TestVerify(t *testing.T) { var buf bytes.Buffer data[tc.modifyByte] ^= 1 if tc.shouldSucceed { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err != nil && err != io.EOF { + n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { t.Errorf("Verification failed when expected to succeed: %v", err) } - if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") } } else { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { t.Errorf("Verification succeeded when expected to fail") } } @@ -382,14 +389,21 @@ func TestVerifyRandom(t *testing.T) { var buf bytes.Buffer // Checks that the random portion of data from the original data is // verified successfully. - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err != nil && err != io.EOF { + n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { t.Errorf("Verification failed for correct data: %v", err) } if size > dataSize-start { size = dataSize - start } - if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") } buf.Reset() @@ -397,7 +411,7 @@ func TestVerifyRandom(t *testing.T) { randBytePos := rand.Int63n(size) data[start+randBytePos] ^= 1 - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { t.Errorf("Verification succeeded for modified data") } } diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md index b102c9b34..eccb1fb2f 100644 --- a/pkg/sentry/fs/g3doc/fuse.md +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -119,8 +119,8 @@ ops can be implemented in parallel. `connection.initializedChan` -- a channel that the requests issued before connection initialization - blocks on. +- a channel that the requests issued before connection initialization blocks + on. `fd.queue` @@ -128,9 +128,8 @@ ops can be implemented in parallel. `fd.completions` -- a map of the requests that have been prepared but - not yet received a response, including the ones on the - `fd.queue`. +- a map of the requests that have been prepared but not yet received a + response, including the ones on the `fd.queue`. `fd.waitQueue` diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 2cf0a38c9..f86a6a0b2 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "strconv" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -192,10 +193,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -204,10 +202,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - if noCrashOnVerificationFailure { - return nil, syserror.EINVAL - } - panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's root hash. @@ -221,10 +216,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -242,10 +234,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -255,10 +244,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - if noCrashOnVerificationFailure { - return nil, syserror.EINVAL - } - panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -270,11 +256,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contain the root hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { - if noCrashOnVerificationFailure { - return nil, syserror.EIO - } - panic(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { + return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child root hash when it's verified the first time. @@ -370,10 +353,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. - if noCrashOnVerificationFailure { - return nil, childErr - } - panic(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -409,10 +389,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - if noCrashOnVerificationFailure { - return nil, childMerkleErr - } - panic(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { // Both the child and the corresponding Merkle tree are missing. @@ -421,7 +398,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // TODO(b/167752508): Investigate possible ways to differentiate // cases that both files are deleted from cases that they never // exist in the file system. - panic(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -580,8 +557,181 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - //TODO(b/159261227): Implement OpenAt. - return nil, nil + // Verity fs is read-only. + if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 { + return nil, syserror.EROFS + } + + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + + start := rp.Start().Impl().(*dentry) + if rp.Done() { + return start.openLocked(ctx, rp, &opts) + } + +afterTrailingSymlink: + parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) + if err != nil { + return nil, err + } + + // Check for search permission in the parent directory. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + + // Open existing child or follow symlink. + parent.dirMu.Lock() + child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds) + parent.dirMu.Unlock() + if err != nil { + return nil, err + } + if child.isSymlink() && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + start = parent + goto afterTrailingSymlink + } + return child.openLocked(ctx, rp, &opts) +} + +// Preconditions: fs.renameMu must be locked. +func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { + // Users should not open the Merkle tree files. Those are for verity fs + // use only. + if strings.Contains(d.name, merklePrefix) { + return nil, syserror.EPERM + } + ats := vfs.AccessTypesForOpenFlags(opts) + if err := d.checkPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + + // Verity fs is read-only. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EROFS + } + + // Get the path to the target file. This is only used to provide path + // information in failure case. + path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD) + if err != nil { + return nil, err + } + + // Open the file in the underlying file system. + lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }, opts) + + // The file should exist, as we succeeded in finding its dentry. If it's + // missing, it indicates an unexpected modification to the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + } + return nil, err + } + + // lowerFD needs to be cleaned up if any error occurs. IncRef will be + // called if a verity FD is successfully created. + defer lowerFD.DecRef(ctx) + + // Open the Merkle tree file corresponding to the current file/directory + // to be used later for verifying Read/Walk. + merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The Merkle tree file should exist, as we succeeded in finding its + // dentry. If it's missing, it indicates an unexpected modification to + // the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + + // merkleReader needs to be cleaned up if any error occurs. IncRef will + // be called if a verity FD is successfully created. + defer merkleReader.DecRef(ctx) + + lowerFlags := lowerFD.StatusFlags() + lowerFDOpts := lowerFD.Options() + var merkleWriter *vfs.FileDescription + var parentMerkleWriter *vfs.FileDescription + + // Only open the Merkle tree files for write if in allowRuntimeEnable + // mode. + if d.fs.allowRuntimeEnable { + merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + // merkleWriter is cleaned up if any error occurs. IncRef will + // be called if a verity FD is created successfully. + defer merkleWriter.DecRef(ctx) + + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err + } + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) + } + + fd := &fileDescription{ + d: d, + lowerFD: lowerFD, + merkleReader: merkleReader, + merkleWriter: merkleWriter, + parentMerkleWriter: parentMerkleWriter, + isDir: d.isDir(), + } + + if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil { + return nil, err + } + lowerFD.IncRef() + merkleReader.IncRef() + if merkleWriter != nil { + merkleWriter.IncRef() + } + if parentMerkleWriter != nil { + parentMerkleWriter.IncRef() + } + return &fd.vfsfd, err } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 249cc1341..3e0bcd02b 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -22,6 +22,7 @@ package verity import ( + "fmt" "strconv" "sync/atomic" @@ -29,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" @@ -135,6 +135,16 @@ func (FilesystemType) Name() string { return Name } +// alertIntegrityViolation alerts a violation of integrity, which usually means +// unexpected modification to the file system is detected. In +// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. +func alertIntegrityViolation(err error, msg string) error { + if noCrashOnVerificationFailure { + return err + } + panic(msg) +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { iopts, ok := opts.InternalData.(InternalFilesystemOptions) @@ -204,15 +214,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } } else if err != nil { - // Failed to get dentry for the root Merkle file. This indicates - // an attack that removed/renamed the root Merkle file, or it's - // never generated. - if noCrashOnVerificationFailure { - fs.vfsfs.DecRef(ctx) - d.DecRef(ctx) - return nil, nil, err - } - panic("Failed to find root Merkle file") + // Failed to get dentry for the root Merkle file. This + // indicates an unexpected modification that removed/renamed + // the root Merkle file, or it's never generated. + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -550,7 +557,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg defer verityMu.Unlock() if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil { - panic("Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") } rootHash, dataSize, err := fd.generateMerkle(ctx) @@ -618,6 +625,54 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // No need to verify if the file is not enabled yet in + // allowRuntimeEnable mode. + if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 { + return fd.lowerFD.PRead(ctx, dst, offset, opts) + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return 0, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + dataReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */) + if err != nil { + return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err)) + } + return n, err +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 905712076..537419657 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error { } // tcr_el1 - data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1 reg.id = _KVM_ARM64_REGS_TCR_EL1 if err := c.setOneRegister(®); err != nil { return err @@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error { c.SetTtbr0Kvm(uintptr(data)) // ttbr1_el1 - data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1) reg.id = _KVM_ARM64_REGS_TTBR1_EL1 if err := c.setOneRegister(®); err != nil { diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 0e2ab716c..508236e46 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -77,6 +77,9 @@ type CPUArchState struct { // lazyVFP is the value of cpacr_el1. lazyVFP uintptr + + // appASID is the asid value of guest application. + appASID uintptr } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 1e477cc49..5f63cbd45 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -302,17 +302,23 @@ // SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. #define SWITCH_TO_APP_PAGETABLE(from) \ - MOVD CPU_TTBR0_APP(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD CPU_APP_ASID(from), R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set) ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; // SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. #define SWITCH_TO_KVM_PAGETABLE(from) \ - MOVD CPU_TTBR0_KVM(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD $1, R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ @@ -328,23 +334,20 @@ #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. - SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 - MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. - REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + WORD $0xd538d092; \ // MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context. MOVD RSV_REG_APP, R20; \ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ ADD $16, RSP, RSP; \ MOVD RSV_REG, PTRACE_R18(R20); \ MOVD RSV_REG_APP, PTRACE_R9(R20); \ - MOVD R20, RSV_REG_APP; \ WORD $0xd5384003; \ // MRS SPSR_EL1, R3 - MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MOVD R3, PTRACE_PSTATE(R20); \ MRS ELR_EL1, R3; \ - MOVD R3, PTRACE_PC(RSV_REG_APP); \ + MOVD R3, PTRACE_PC(R20); \ WORD $0xd5384103; \ // MRS SP_EL0, R3 - MOVD R3, PTRACE_SP(RSV_REG_APP); + MOVD R3, PTRACE_SP(R20); // KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1. #define KERNEL_ENTRY_FROM_EL1 \ @@ -359,6 +362,13 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// storeAppASID writes the application's asid value. +TEXT ·storeAppASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + MRS TPIDR_EL1, RSV_REG + MOVD R1, CPU_APP_ASID(RSV_REG) + RET + // Halt halts execution. TEXT ·Halt(SB),NOSPLIT,$0 // Clear bluepill. @@ -416,7 +426,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8 MOVD R8, ret+0(FP) RET -#define STACK_FRAME_SIZE 16 +#define STACK_FRAME_SIZE 32 // kernelExitToEl0 is the entrypoint for application in guest_el0. // Prepare the vcpu environment for container application. @@ -460,15 +470,16 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 SUB $STACK_FRAME_SIZE, RSP, RSP STP (RSV_REG, RSV_REG_APP), 16*0(RSP) + STP (R0, R1), 16*1(RSP) WORD $0xd538d092 //MRS TPIDR_EL1, R18 SWITCH_TO_APP_PAGETABLE(RSV_REG) + LDP 16*1(RSP), (R0, R1) LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) ADD $STACK_FRAME_SIZE, RSP, RSP - ISB $15 ERET() // kernelExitToEl1 is the entrypoint for sentry in guest_el1. @@ -484,6 +495,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP + SWITCH_TO_KVM_PAGETABLE(RSV_REG) + MRS TPIDR_EL1, RSV_REG + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index d0afa1aaa..14774c5db 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,6 +53,11 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + storeAppASID(uintptr(switchOpts.UserASID)) + if switchOpts.Flush { + FlushTlbAll() + } + regs := switchOpts.Registers regs.Pstate &= ^uint64(PsrFlagsClear) diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 00e52c8af..2f1abcb0f 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -16,6 +16,15 @@ package ring0 +// storeAppASID writes the application's asid value. +func storeAppASID(asid uintptr) + +// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. +func LocalFlushTlbAll() + +// FlushTlbAll flush all tlb. +func FlushTlbAll() + // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 86bfbe46f..8aabf7d0e 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,20 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 + DSB $6 // dsb(nshst) + WORD $0xd508871f // __tlbi(vmalle1) + DSB $7 // dsb(nsh) + ISB $15 + RET + +TEXT ·FlushTlbAll(SB),NOSPLIT,$0 + DSB $10 // dsb(ishst) + WORD $0xd508831f // __tlbi(vmalle1is) + DSB $11 // dsb(ish) + ISB $15 + RET + TEXT ·GetTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index f3de962f0..1d86b4bcf 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -41,6 +41,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2e568bc3d..b462924af 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -482,8 +482,35 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(context.Context) { +func (s *socketOpsCommon) Release(ctx context.Context) { + e, ch := waiter.NewChannelEntry(nil) + s.EventRegister(&e, waiter.EventHUp|waiter.EventErr) + defer s.EventUnregister(&e) + s.Endpoint.Close() + + // SO_LINGER option is valid only for TCP. For other socket types + // return after endpoint close. + if family, skType, _ := s.Type(); skType != linux.SOCK_STREAM || (family != linux.AF_INET && family != linux.AF_INET6) { + return + } + + var v tcpip.LingerOption + if err := s.Endpoint.GetSockOpt(&v); err != nil { + return + } + + // The case for zero timeout is handled in tcp endpoint close function. + // Close is blocked until either: + // 1. The endpoint state is not in any of the states: FIN-WAIT1, + // CLOSING and LAST_ACK. + // 2. Timeout is reached. + if v.Enabled && v.Timeout != 0 { + t := kernel.TaskFromContext(ctx) + start := t.Kernel().MonotonicClock().Now() + deadline := start.Add(v.Timeout) + t.BlockWithDeadline(ch, true, deadline) + } } // Read implements fs.FileOperations.Read. @@ -1155,7 +1182,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - linger := linux.Linger{} + var v tcpip.LingerOption + var linger linux.Linger + if err := ep.GetSockOpt(&v); err != nil { + return &linger, nil + } + + if v.Enabled { + linger.OnOff = 1 + } + linger.Linger = int32(v.Timeout.Seconds()) return &linger, nil case linux.SO_SNDTIMEO: @@ -1884,7 +1920,10 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return nil + return syserr.TranslateNetstackError( + ep.SetSockOpt(&tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger)})) case linux.SO_DETACH_FILTER: // optval is ignored. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index cbbdd000f..08504560c 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -945,8 +945,14 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case *tcpip.LingerOption: + return nil + + default: + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption + } } // LastError implements Endpoint.LastError. diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..6c410c5a6 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,14 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + srcs = [ + "errors_test.go", + ], + library = "rawfile", + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index a0a873c84..604868fd8 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error // *tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. +// tcpip.ErrInvalidEndpointState (EINVAL). func TranslateErrno(e syscall.Errno) *tcpip.Error { - if err := translations[e]; err != nil { - return err + if e > 0 && e < syscall.Errno(len(translations)) { + if err := translations[e]; err != nil { + return err + } } return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go new file mode 100644 index 000000000..e4cdc66bd --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestTranslateErrno(t *testing.T) { + for _, test := range []struct { + errno syscall.Errno + translated *tcpip.Error + }{ + { + errno: syscall.Errno(0), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(maxErrno), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(514), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.EEXIST, + translated: tcpip.ErrDuplicateAddress, + }, + } { + got := TranslateErrno(test.errno) + if got != test.translated { + t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + } + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 863ef6bee..1f1a1426b 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -665,33 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } - // Check if address is a broadcast address for the endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { n.mu.RUnlock() return ref } } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the IPv4 address is found in the NIC's subnets and the NIC - // is a loopback interface. - createTempEP := spoofingOrPromiscuous - if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber { - for _, r := range n.mu.endpoints { - addr := r.addrWithPrefix() - subnet := addr.Subnet() - if subnet.Contains(address) { - createTempEP = true - break - } - } - } n.mu.RUnlock() - if !createTempEP { + if !spoofingOrPromiscuous { return nil } @@ -704,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -// getRefForBroadcastLocked returns an endpoint where address is the IPv4 -// broadcast address for the endpoint's network. +// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the +// broadcast address for the endpoint's network or an address in the endpoint's +// subnet if the NIC is a loopback interface. This matches linux behaviour. // -// n.mu MUST be read locked. -func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { +// n.mu MUST be read or write locked. +func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint { for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses. + // Only IPv4 has a notion of broadcast addresses or considers the loopback + // interface bound to an address's whole subnet (on linux). if ref.protocol != header.IPv4ProtocolNumber { continue } - addr := ref.addrWithPrefix() - subnet := addr.Subnet() - if subnet.IsBroadcast(address) && ref.tryIncRef() { + subnet := ref.addrWithPrefix().Subnet() + if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() { return ref } } @@ -745,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } - // Check if address is a broadcast address for an endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { return ref } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..0774b5382 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a TCP packet with a non-unicast source or destination // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isInboundUnicast(r *Route) bool { + return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) +} + +func isOutboundUnicast(r *Route) bool { + return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5e34e27ba..b2ddb24ec 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1194,6 +1194,19 @@ const ( TCPTimeWaitReuseLoopbackOnly ) +// LingerOption is used by SetSockOpt/GetSockOpt to set/get the +// duration for which a socket lingers before returning from Close. +// +// +stateify savable +type LingerOption struct { + Enabled bool + Timeout time.Duration +} + +func (*LingerOption) isGettableSocketOption() {} + +func (*LingerOption) isSettableSocketOption() {} + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 1b18023c5..fecbe7ba7 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -187,3 +187,64 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }) } } + +// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address +// in a loopback interface's associated subnet is bound to the permanently bound +// address. +func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { + const nicID = 1 + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + addrBytes := []byte(ipv4Addr.Address) + addrBytes[len(addrBytes)-1]++ + otherAddr := tcpip.Address(addrBytes) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err) + } + defer r.Release() + + params := stack.NetworkHeaderParams{ + Protocol: 111, + TTL: 64, + TOS: stack.DefaultTOS, + } + data := buffer.View([]byte{1, 2, 3, 4}) + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + } + + // Removing the address should make the endpoint invalid. + if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) + } + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 52c27e045..659acbc7a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { }) } } + +// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all +// interested endpoints. +func TestReuseAddrAndBroadcast(t *testing.T) { + const ( + nicID = 1 + localPort = 9000 + loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + ) + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + broadcastAddr tcpip.Address + }{ + { + name: "Subnet directed broadcast", + broadcastAddr: loopbackBroadcast, + }, + { + name: "IPv4 broadcast", + broadcastAddr: header.IPv4Broadcast, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x7f\x00\x00\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + // We use the empty subnet instead of just the loopback subnet so we + // also have a route to the IPv4 Broadcast address. + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // We create endpoints that bind to both the wildcard address and the + // broadcast address to make sure both of these types of "broadcast + // interested" endpoints receive broadcast packets. + wq := waiter.Queue{} + var eps []tcpip.Endpoint + for _, bindWildcard := range []bool{false, true} { + // Create multiple endpoints for each type of "broadcast interested" + // endpoint so we can test that all endpoints receive the broadcast + // packet. + for i := 0; i < 2; i++ { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + defer ep.Close() + + if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) + } + + bindAddr := tcpip.FullAddress{Port: localPort} + if bindWildcard { + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } else { + bindAddr.Addr = test.broadcastAddr + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } + + eps = append(eps, ep) + } + } + + for i, wep := range eps { + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.broadcastAddr, + Port: localPort, + }, + } + if n, _, err := wep.Write(data, writeOpts); err != nil { + t.Fatalf("eps[%d].Write(_, _): %s", i, err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + } + + for j, rep := range eps { + if gotPayload, _, err := rep.Read(nil); err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) + } + } + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6d5046a3d..faea7f2bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -654,6 +654,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1007,6 +1010,26 @@ func (e *endpoint) Close() { return } + if e.linger.Enabled && e.linger.Timeout == 0 { + s := e.EndpointState() + isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv + if isResetState { + // Close the endpoint without doing full shutdown and + // send a RST. + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.closeNoShutdownLocked() + + // Wake up worker to close the endpoint. + switch s { + case StateSynRecv: + e.notifyProtocolGoroutine(notifyClose) + default: + e.notifyProtocolGoroutine(notifyTickleWorker) + } + return + } + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -1807,6 +1830,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + e.LockUser() + e.linger = *v + e.UnlockUser() + default: return nil } @@ -2031,6 +2059,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } + case *tcpip.LingerOption: + e.LockUser() + *o = e.linger + e.UnlockUser() + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0d13e1efd..b1e5f1b24 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index baf7df197..85e8c1c75 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/test/fuse/README.md b/test/fuse/README.md index 7a1839714..65add57e2 100644 --- a/test/fuse/README.md +++ b/test/fuse/README.md @@ -14,8 +14,8 @@ server. It creates a `socketpair(2)` to send and receive control commands and data between the client and the server. Because the FUSE server runs in the background thread, gTest cannot catch its assertion failure immediately. Thus, `TearDown()` function sends command to the FUSE server to check if all gTest -assertion in the server are successful and all requests and preset responses -are consumed. +assertion in the server are successful and all requests and preset responses are +consumed. ## Communication Diagram @@ -29,80 +29,80 @@ however, it is still helpful to know when the client waits for the server to complete a command and when the server awaits the next instruction. ``` - | Client (Testing Thread) | Server (FUSE Server Thread) - | | - | >TEST_F() | - | >SetUp() | - | =MountFuse() | - | >SetUpFuseServer() | - | [create communication socket]| - | =fork() | =fork() - | [wait server complete] | - | | =ServerConsumeFuseInit() - | | =ServerCompleteWith() - | <SetUpFuseServer() | - | <SetUp() | - | [testing main] | - | | >ServerFuseLoop() - | | [poll on socket and fd] - | >SetServerResponse() | - | [write data to socket] | - | [wait server complete] | - | | [socket event occurs] - | | >ServerHandleCommand() - | | >ServerReceiveResponse() - | | [read data from socket] - | | [save data to memory] - | | <ServerReceiveResponse() - | | =ServerCompleteWith() - | <SetServerResponse() | - | | <ServerHandleCommand() - | >[Do fs operation] | - | [wait for fs response] | - | | [fd event occurs] - | | >ServerProcessFuseRequest() - | | =[read fs request] - | | =[save fs request to memory] - | | =[write fs response] - | <[Do fs operation] | - | | <ServerProcessFuseRequest() - | | - | =[Test fs operation result] | - | | - | >GetServerActualRequest() | - | [write data to socket] | - | [wait data from server] | - | | [socket event occurs] - | | >ServerHandleCommand() - | | >ServerSendReceivedRequest() - | | [write data to socket] - | [read data from socket] | - | [wait server complete] | - | | <ServerSendReceivedRequest() - | | =ServerCompleteWith() - | <GetServerActualRequest() | - | | <ServerHandleCommand() - | | - | =[Test actual request] | - | | - | >TearDown() | - | ... | - | >GetServerNumUnsentResponses() | - | [write data to socket] | - | [wait server complete] | - | | [socket event arrive] - | | >ServerHandleCommand() - | | >ServerSendData() - | | [write data to socket] - | | <ServerSendData() - | | =ServerCompleteWith() - | [read data from socket] | - | [test if all succeeded] | - | <GetServerNumUnsentResponses() | - | | <ServerHandleCommand() - | =UnmountFuse() | - | <TearDown() | - | <TEST_F() | +| Client (Testing Thread) | Server (FUSE Server Thread) +| | +| >TEST_F() | +| >SetUp() | +| =MountFuse() | +| >SetUpFuseServer() | +| [create communication socket]| +| =fork() | =fork() +| [wait server complete] | +| | =ServerConsumeFuseInit() +| | =ServerCompleteWith() +| <SetUpFuseServer() | +| <SetUp() | +| [testing main] | +| | >ServerFuseLoop() +| | [poll on socket and fd] +| >SetServerResponse() | +| [write data to socket] | +| [wait server complete] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerReceiveResponse() +| | [read data from socket] +| | [save data to memory] +| | <ServerReceiveResponse() +| | =ServerCompleteWith() +| <SetServerResponse() | +| | <ServerHandleCommand() +| >[Do fs operation] | +| [wait for fs response] | +| | [fd event occurs] +| | >ServerProcessFuseRequest() +| | =[read fs request] +| | =[save fs request to memory] +| | =[write fs response] +| <[Do fs operation] | +| | <ServerProcessFuseRequest() +| | +| =[Test fs operation result] | +| | +| >GetServerActualRequest() | +| [write data to socket] | +| [wait data from server] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerSendReceivedRequest() +| | [write data to socket] +| [read data from socket] | +| [wait server complete] | +| | <ServerSendReceivedRequest() +| | =ServerCompleteWith() +| <GetServerActualRequest() | +| | <ServerHandleCommand() +| | +| =[Test actual request] | +| | +| >TearDown() | +| ... | +| >GetServerNumUnsentResponses() | +| [write data to socket] | +| [wait server complete] | +| | [socket event arrive] +| | >ServerHandleCommand() +| | >ServerSendData() +| | [write data to socket] +| | <ServerSendData() +| | =ServerCompleteWith() +| [read data from socket] | +| [test if all succeeded] | +| <GetServerNumUnsentResponses() | +| | <ServerHandleCommand() +| =UnmountFuse() | +| <TearDown() | +| <TEST_F() | ``` ## Running the tests @@ -124,17 +124,18 @@ $ bazel test --test_tag_filters=fuse //test/fuse/... ## Writing a new FUSE test -1. Add test targets in `BUILD` and `linux/BUILD`. -2. Inherit your test from `FuseTest` base class. It allows you to: - - Fork a fake FUSE server in background during each test setup. - - Create a pair of sockets for communication and provide utility functions. - - Stop FUSE server and check if error occurs in it after test completes. -3. Build the expected opcode-response pairs of your FUSE operation. -4. Call `SetServerResponse()` to preset the next expected opcode and response. -5. Do real filesystem operations (FUSE is mounted at `mount_point_`). -6. Check FUSE response and/or errors. -7. Retrieve FUSE request by `GetServerActualRequest()`. -8. Check if the request is as expected. +1. Add test targets in `BUILD` and `linux/BUILD`. +2. Inherit your test from `FuseTest` base class. It allows you to: + - Fork a fake FUSE server in background during each test setup. + - Create a pair of sockets for communication and provide utility + functions. + - Stop FUSE server and check if error occurs in it after test completes. +3. Build the expected opcode-response pairs of your FUSE operation. +4. Call `SetServerResponse()` to preset the next expected opcode and response. +5. Do real filesystem operations (FUSE is mounted at `mount_point_`). +6. Check FUSE response and/or errors. +7. Retrieve FUSE request by `GetServerActualRequest()`. +8. Check if the request is as expected. A few customized matchers used in syscalls test are encouraged to test the outcome of filesystem operations. Such as: @@ -158,11 +159,12 @@ FUSE server in response to a sequence of FUSE requests. The lifecycle of a command contains following steps: -1. The testing thread sends a `FuseTestCmd` via socket and waits for completion. -2. The FUSE server receives the command and does corresponding action. -3. (Optional) The testing thread reads data from socket. -4. The FUSE server sends a success indicator via socket after processing. -5. The testing thread gets the success signal and continues testing. +1. The testing thread sends a `FuseTestCmd` via socket and waits for + completion. +2. The FUSE server receives the command and does corresponding action. +3. (Optional) The testing thread reads data from socket. +4. The FUSE server sends a success indicator via socket after processing. +5. The testing thread gets the success signal and continues testing. The success indicator, i.e. `WaitServerComplete()`, is crucial at the end of each `FuseTestCmd` sent from the testing thread. Because we don't want to begin @@ -172,16 +174,15 @@ supported now. To add a new `FuseTestCmd`, one must comply with following format: -1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h` -2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`. - This is how the testing thread will call to send control message. Define how - many bytes you want to send along with the command and what you will expect - to receive. Finally it should block and wait for a success indicator from - the FUSE server. -3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use - `ServerSendData()` or declare a new private function such as - `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private - since only the FUSE server (forked from `FuseTest` base class) can call it. - This is the server part of the specific `FuseTestCmd` and the format of the - data should be consistent with what the client expects in the previous step. - +1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h` +2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`. + This is how the testing thread will call to send control message. Define how + many bytes you want to send along with the command and what you will expect + to receive. Finally it should block and wait for a success indicator from + the FUSE server. +3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use + `ServerSendData()` or declare a new private function such as + `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private + since only the FUSE server (forked from `FuseTest` base class) can call it. + This is the server part of the specific `FuseTestCmd` and the format of the + data should be consistent with what the client expects in the previous step. diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc index a033db117..5b45804e1 100644 --- a/test/fuse/linux/fuse_base.cc +++ b/test/fuse/linux/fuse_base.cc @@ -24,8 +24,8 @@ #include <sys/uio.h> #include <unistd.h> -#include "absl/strings/str_format.h" #include "gtest/gtest.h" +#include "absl/strings/str_format.h" #include "test/util/fuse_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" diff --git a/test/fuse/linux/fuse_fd_util.cc b/test/fuse/linux/fuse_fd_util.cc index 4a2505b00..30d1157bb 100644 --- a/test/fuse/linux/fuse_fd_util.cc +++ b/test/fuse/linux/fuse_fd_util.cc @@ -59,4 +59,3 @@ Cleanup FuseFdTest::CloseFD(FileDescriptor &fd) { } // namespace testing } // namespace gvisor - diff --git a/test/fuse/linux/readdir_test.cc b/test/fuse/linux/readdir_test.cc index ab61eb676..2afb4b062 100644 --- a/test/fuse/linux/readdir_test.cc +++ b/test/fuse/linux/readdir_test.cc @@ -127,8 +127,9 @@ TEST_F(ReaddirTest, SingleEntry) { char *readdir_payload = readdir_payload_vec.data(); // Use fake ino for other directories. - fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir-2); - fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(), ino_dir-1); + fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir - 2); + fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(), + ino_dir - 1); fill_fuse_dirent( readdir_payload + dot_file_dirent_size + dot_dot_file_dirent_size, test_file.c_str(), ino_dir); @@ -148,8 +149,9 @@ TEST_F(ReaddirTest, SingleEntry) { std::vector<char> buf(4090, 0); int nread, off = 0, i = 0; - EXPECT_THAT(nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), - SyscallSucceeds()); + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceeds()); for (; off < nread;) { struct dirent64 *ent = (struct dirent64 *)(buf.data() + off); off += ent->d_reclen; @@ -166,8 +168,9 @@ TEST_F(ReaddirTest, SingleEntry) { } } - EXPECT_THAT(nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(0)); + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(0)); SkipServerActualRequest(); // READDIR. SkipServerActualRequest(); // READDIR with no data. diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index d4c486f9c..96a0fb6c8 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -172,7 +172,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co } device := mkDevice(dut) - remoteIPv6, remoteMAC, dutDeviceID, testNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr) + remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr) // Create the Docker container for the testbench. testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) @@ -181,12 +181,16 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co containerTestbenchBinary := filepath.Join("/packetimpact", tbb) testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb)) + // snifferNetDev is a network device on the test orchestrator that we will + // run sniffer (tcpdump or tshark) on and inject traffic to, not to be + // confused with the device on the DUT. + const snifferNetDev = "eth2" // Run tcpdump in the test bench unbuffered, without DNS resolution, just on // the interface with the test packets. snifferArgs := []string{ "tcpdump", "-S", "-vvv", "-U", "-n", - "-i", testNetDev, + "-i", snifferNetDev, "-w", testOutputDir + "/dump.pcap", } snifferRegex := "tcpdump: listening.*\n" @@ -194,7 +198,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co // Run tshark in the test bench unbuffered, without DNS resolution, just on // the interface with the test packets. snifferArgs = []string{ - "tshark", "-V", "-l", "-n", "-i", testNetDev, + "tshark", "-V", "-l", "-n", "-i", snifferNetDev, "-o", "tcp.check_checksum:TRUE", "-o", "udp.check_checksum:TRUE", } @@ -228,7 +232,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co // this, we can install the following iptables rules. The raw socket that // packetimpact tests use will still be able to see everything. for _, bin := range []string{"iptables", "ip6tables"} { - if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil { t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) } } @@ -251,7 +255,8 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co "--remote_ipv6", remoteIPv6.String(), "--remote_mac", remoteMAC.String(), "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID), - "--device", testNetDev, + "--local_device", snifferNetDev, + "--remote_device", dutTestNetDev, fmt.Sprintf("--native=%t", native), ) testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index ff269d949..6165ab293 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -16,11 +16,13 @@ package testbench import ( "context" + "encoding/binary" "flag" "net" "strconv" "syscall" "testing" + "time" pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" @@ -701,6 +703,20 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, fl return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } +// SetSockLingerOption sets SO_LINGER socket option on the DUT. +func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Duration, enable bool) { + var linger unix.Linger + if enable { + linger.Onoff = 1 + } + linger.Linger = int32(timeout / time.Second) + + buf := make([]byte, 8) + binary.LittleEndian.PutUint32(buf, uint32(linger.Onoff)) + binary.LittleEndian.PutUint32(buf[4:], uint32(linger.Linger)) + dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf) +} + // Shutdown calls shutdown on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ShutdownWithErrno. diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 57e822725..193bb2dc8 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -139,7 +139,7 @@ type Injector struct { func NewInjector(t *testing.T) (Injector, error) { t.Helper() - ifInfo, err := net.InterfaceByName(Device) + ifInfo, err := net.InterfaceByName(LocalDevice) if err != nil { return Injector{}, err } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index e3629e1f3..0073a1361 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -29,8 +29,11 @@ import ( var ( // Native indicates that the test is being run natively. Native = false - // Device is the local device on the test network. - Device = "" + // LocalDevice is the device that testbench uses to inject traffic. + LocalDevice = "" + // RemoteDevice is the device name on the DUT, individual tests can + // use the name to construct tests. + RemoteDevice = "" // LocalIPv4 is the local IPv4 address on the test network. LocalIPv4 = "" @@ -80,7 +83,8 @@ func RegisterFlags(fs *flag.FlagSet) { fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") - fs.StringVar(&Device, "device", Device, "local device for test packets") + fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic") + fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT") fs.BoolVar(&Native, "native", Native, "whether the test is running natively") fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 6dda05102..f850dfcd8 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -320,3 +320,13 @@ packetimpact_go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +packetimpact_go_test( + name = "tcp_linger", + srcs = ["tcp_linger_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go new file mode 100644 index 000000000..913e49e06 --- /dev/null +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -0,0 +1,253 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_linger_test + +import ( + "context" + "flag" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPIPv4) { + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + return acceptFD, listenFD, conn +} + +func closeAll(t *testing.T, dut testbench.DUT, listenFD int32, conn testbench.TCPIPv4) { + conn.Close(t) + dut.Close(t, listenFD) + dut.TearDown() +} + +// lingerDuration is the timeout value used with SO_LINGER socket option. +const lingerDuration = 3 * time.Second + +// TestTCPLingerZeroTimeout tests when SO_LINGER is set with zero timeout. DUT +// should send RST-ACK when socket is closed. +func TestTCPLingerZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Close(t, acceptFD) + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK +// when socket is closed. +func TestTCPLingerOff(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.Close(t, acceptFD) + + // If SO_LINGER is not set, DUT should send a FIN-ACK. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout. +// DUT should close the socket after timeout. +func TestTCPLingerNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerSendNonZeroTimeout tests when SO_LINGER is set with non-zero +// timeout and send a packet. DUT should close the socket after timeout. +func TestTCPLingerSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithSendNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerShutdownZeroTimeout tests SO_LINGER with shutdown() and zero +// timeout. DUT should send RST-ACK when socket is closed. +func TestTCPLingerShutdownZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Close(t, acceptFD) + + // Shutdown will send FIN-ACK with read/write option. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and +// non-zero timeout. DUT should close the socket after timeout. +func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"shutdownRDWR", true}, + {"shutdownRDWR", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 066338ee3..22b526f59 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -5,7 +5,7 @@ package(licenses = ["notice"]) runtime_test( name = "go1.12", - exclude_file = "exclude_go1.12.csv", + exclude_file = "exclude/go1.12.csv", lang = "go", shard_count = 8, ) @@ -13,28 +13,28 @@ runtime_test( runtime_test( name = "java11", batch = 100, - exclude_file = "exclude_java11.csv", + exclude_file = "exclude/java11.csv", lang = "java", shard_count = 16, ) runtime_test( name = "nodejs12.4.0", - exclude_file = "exclude_nodejs12.4.0.csv", + exclude_file = "exclude/nodejs12.4.0.csv", lang = "nodejs", shard_count = 8, ) runtime_test( name = "php7.3.6", - exclude_file = "exclude_php7.3.6.csv", + exclude_file = "exclude/php7.3.6.csv", lang = "php", shard_count = 8, ) runtime_test( name = "python3.7.3", - exclude_file = "exclude_python3.7.3.csv", + exclude_file = "exclude/python3.7.3.csv", lang = "python", shard_count = 8, ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md new file mode 100644 index 000000000..9dda1a728 --- /dev/null +++ b/test/runtimes/README.md @@ -0,0 +1,62 @@ +# gVisor Runtime Tests + +App Engine uses gvisor to sandbox application containers. The runtime tests aim +to test `runsc` compatibility with these +[standard runtimes](https://cloud.google.com/appengine/docs/standard/runtimes). +The test itself runs the language-defined tests inside the sandboxed standard +runtime container. + +Note: [Ruby runtime](https://cloud.google.com/appengine/docs/standard/ruby) is +currently in beta mode and so we do not run tests for it yet. + +### Testing Locally + +To run runtime tests individually from a given runtime, use the following table. + +Language | Version | Download Image | Run Test(s) +-------- | ------- | ------------------------------------------- | ----------- +Go | 1.12 | `make -C images load-runtimes_go1.12` | If the test name ends with `.go`, it is an on-disk test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 ( cd /usr/local/go/test ; go run run.go -v -- <TEST_NAME>... )` <br> Otherwise it is a tool test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 go tool dist test -v -no-rebuild ^TEST1$\|^TEST2$...` +Java | 11 | `make -C images load-runtimes_java11` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/java11 jtreg -agentvm -dir:/root/test/jdk -noreport -timeoutFactor:20 -verbose:summary <TEST_NAME>...` +NodeJS | 12.4.0 | `make -C images load-runtimes_nodejs12.4.0` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/nodejs12.4.0 python tools/test.py --timeout=180 <TEST_NAME>...` +Php | 7.3.6 | `make -C images load-runtimes_php7.3.6` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/php7.3.6 make test "TESTS=<TEST_NAME>..."` +Python | 3.7.3 | `make -C images load-runtimes_python3.7.3` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/python3.7.3 ./python -m test <TEST_NAME>...` + +To run an entire runtime test locally, use the following table. + +Note: java runtime test take 1+ hours with 16 cores. + +Language | Version | Running the test suite +-------- | ------- | ---------------------------------------- +Go | 1.12 | `make go1.12-runtime-tests{_vfs2}` +Java | 11 | `make java11-runtime-tests{_vfs2}` +NodeJS | 12.4.0 | `make nodejs12.4.0-runtime-tests{_vfs2}` +Php | 7.3.6 | `make php7.3.6-runtime-tests{_vfs2}` +Python | 3.7.3 | `make python3.7.3-runtime-tests{_vfs2}` + +#### Clean Up + +Sometimes when runtime tests fail or when the testing container itself crashes +unexpectedly, the containers are not removed or sometimes do not even exit. This +can cause some docker commands like `docker system prune` to hang forever. + +Here are some helpful commands (should be executed in order): + +```bash +docker ps -a # Lists all docker processes; useful when investigating hanging containers. +docker kill $(docker ps -a -q) # Kills all running containers. +docker rm $(docker ps -a -q) # Removes all exited containers. +docker system prune # Remove unused data. +``` + +### Testing Infrastructure + +There are 3 components to this tests infrastructure: + +- [`runner`](runner) - This is the test entrypoint. This is the binary is + invoked by `bazel test`. The runner spawns the target runtime container + using `runsc` and then copies over the `proctor` binary into the container. +- [`proctor`](proctor) - This binary acts as our agent inside the container + which communicates with the runner and actually executes tests. +- [`exclude`](exclude) - Holds a CSV file for each language runtime containing + the full path of tests that should be excluded from running along with a + reason for exclusion. diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude/go1.12.csv index 81e02cf64..81e02cf64 100644 --- a/test/runtimes/exclude_go1.12.csv +++ b/test/runtimes/exclude/go1.12.csv diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude/java11.csv index 997a29cad..997a29cad 100644 --- a/test/runtimes/exclude_java11.csv +++ b/test/runtimes/exclude/java11.csv diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index 1740dbb76..1740dbb76 100644 --- a/test/runtimes/exclude_nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index 815a137c5..815a137c5 100644 --- a/test/runtimes/exclude_php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv index 8760f8951..8760f8951 100644 --- a/test/runtimes/exclude_python3.7.3.csv +++ b/test/runtimes/exclude/python3.7.3.csv diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index de753fc4e..e5d43cf2e 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2430,6 +2430,7 @@ cc_library( ":socket_netlink_route_util", ":socket_test_util", "//test/util:capability_util", + "//test/util:cleanup", gtest, ], alwayslink = 1, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 54fee2e82..11fcec443 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -1116,9 +1116,6 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - // TODO(gvisor.dev/issue/1400): Remove this after SO_LINGER is fixed. - SKIP_IF(IsRunningOnGvisor()); - // Create the listening socket. const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); @@ -1178,12 +1175,20 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); ASSERT_EQ(addrlen, listener.addr_len); - int err; - socklen_t optlen = sizeof(err); - ASSERT_THAT(getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), - SyscallSucceeds()); - ASSERT_EQ(err, ECONNRESET); - ASSERT_EQ(optlen, sizeof(err)); + // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed. + if (IsRunningOnGvisor()) { + char buf[10]; + ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(ECONNRESET)); + } else { + int err; + socklen_t optlen = sizeof(err); + ASSERT_THAT( + getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, ECONNRESET); + ASSERT_EQ(optlen, sizeof(err)); + } } // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 04356b780..f4b69c46c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -1080,5 +1080,124 @@ TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { } } +// Test setsockopt and getsockopt for a socket with SO_LINGER option. +TEST_P(TCPSocketPairTest, SetAndGetLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Check getsockopt before SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got_linger)); + struct linger want_linger = {}; + EXPECT_EQ(0, memcmp(&want_linger, &got_linger, got_len)); + + // Set and get SO_LINGER with negative values. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = -3; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(sl.l_onoff, got_linger.l_onoff); + // Linux returns a different value as it uses HZ to convert the seconds to + // jiffies which overflows for negative values. We want to be compatible with + // linux for getsockopt return value. + if (IsRunningOnGvisor()) { + EXPECT_EQ(sl.l_linger, got_linger.l_linger); + } + + // Set and get SO_LINGER option with positive values. + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test socket to disable SO_LINGER option. +TEST_P(TCPSocketPairTest, SetOffLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + sl.l_onoff = 0; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set to zero. + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test close on dup'd socket with SO_LINGER option set. +TEST_P(TCPSocketPairTest, CloseWithLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + FileDescriptor dupFd = FileDescriptor(dup(sockets->first_fd())); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + char buf[10] = {}; + // Write on dupFd should succeed as socket will not be closed until + // all references are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); + + // Close the socket. + dupFd.reset(); + // Write on dupFd should fail as all references for socket are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index bbe356116..6e4ecd680 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -450,5 +450,35 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) { SyscallFailsWithErrno(EOPNOTSUPP)); } +// Test the SO_LINGER option can be set/get on udp socket. +TEST_P(UDPSocketPairTest, SoLingerFail) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT( + getsockopt(sockets->first_fd(), level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + // Linux returns the values which are set in the SetSockOpt for SO_LINGER. + // In gVisor, we do not store the linger values for UDP as SO_LINGER for UDP + // is a no-op. + if (IsRunningOnGvisor()) { + struct linger want_linger = {}; + EXPECT_EQ(0, memcmp(&want_linger, &got_linger, length)); + } else { + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); + } +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc index 79eb48afa..49a0f06d9 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -15,10 +15,12 @@ #include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" #include <arpa/inet.h> +#include <poll.h> #include "gtest/gtest.h" #include "test/syscalls/linux/socket_netlink_route_util.h" #include "test/util/capability_util.h" +#include "test/util/cleanup.h" namespace gvisor { namespace testing { @@ -33,9 +35,23 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { // Add an IP address to the loopback interface. Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); struct in_addr addr; - EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); - EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, /*prefixlen=*/24, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcv_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -45,10 +61,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { TestAddress sender_addr("V4NotAssignd1"); sender_addr.addr.ss_family = AF_INET; sender_addr.addr_len = sizeof(sockaddr_in); - EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.2", + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2", &(reinterpret_cast<sockaddr_in*>(&sender_addr.addr) ->sin_addr.s_addr))); - EXPECT_THAT( + ASSERT_THAT( bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); @@ -58,10 +74,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { TestAddress receiver_addr("V4NotAssigned2"); receiver_addr.addr.ss_family = AF_INET; receiver_addr.addr_len = sizeof(sockaddr_in); - EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.254", + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254", &(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr) ->sin_addr.s_addr))); - EXPECT_THAT( + ASSERT_THAT( bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), receiver_addr.addr_len), SyscallSucceeds()); @@ -70,10 +86,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); - EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len); char send_buf[kSendBufSize]; RandomizeBuffer(send_buf, kSendBufSize); - EXPECT_THAT( + ASSERT_THAT( RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, reinterpret_cast<sockaddr*>(&receiver_addr.addr), receiver_addr.addr_len), @@ -83,7 +99,126 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { char recv_buf[kSendBufSize] = {}; ASSERT_THAT(RetryEINTR(recv)(rcv_sock->get(), recv_buf, kSendBufSize, 0), SyscallSucceedsWithValue(kSendBufSize)); - EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)); + ASSERT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)); +} + +// Tests that broadcast packets are delivered to all interested sockets +// (wildcard and broadcast address specified sockets). +// +// Note, we cannot test the IPv4 Broadcast (255.255.255.255) because we do +// not have a route to it. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { + constexpr uint16_t kPort = 9876; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + // Number of sockets per socket type. + constexpr int kNumSocketsPerType = 2; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + 24 /* prefixlen */, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + TestAddress broadcast_address("SubnetBroadcastAddress"); + broadcast_address.addr.ss_family = AF_INET; + broadcast_address.addr_len = sizeof(sockaddr_in); + auto broadcast_address_in = + reinterpret_cast<sockaddr_in*>(&broadcast_address.addr); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.255", + &broadcast_address_in->sin_addr.s_addr)); + broadcast_address_in->sin_port = htons(kPort); + + TestAddress any_address = V4Any(); + reinterpret_cast<sockaddr_in*>(&any_address.addr)->sin_port = htons(kPort); + + // We create sockets bound to both the wildcard address and the broadcast + // address to make sure both of these types of "broadcast interested" sockets + // receive broadcast packets. + std::vector<std::unique_ptr<FileDescriptor>> socks; + for (bool bind_wildcard : {false, true}) { + // Create multiple sockets for each type of "broadcast interested" + // socket so we can test that all sockets receive the broadcast packet. + for (int i = 0; i < kNumSocketsPerType; i++) { + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto idx = socks.size(); + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + if (bind_wildcard) { + ASSERT_THAT( + bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr), + any_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } else { + ASSERT_THAT(bind(sock->get(), + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } + + socks.push_back(std::move(sock)); + } + } + + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + + // Broadcasts from each socket should be received by every socket (including + // the sending socket). + for (int w = 0; w < socks.size(); w++) { + auto& w_sock = socks[w]; + ASSERT_THAT( + RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "]"; + + // Check that we received the packet on all sockets. + for (int r = 0; r < socks.size(); r++) { + auto& r_sock = socks[r]; + + struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)) + << "write socks[" << w << "] & read socks[" << r << "]"; + + char recv_buf[kSendBufSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(r_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + } + } } } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index bde1dbb4d..a354f3f80 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -26,6 +26,62 @@ namespace { constexpr uint32_t kSeq = 12345; +// Types of address modifications that may be performed on an interface. +enum class LinkAddrModification { + kAdd, + kDelete, +}; + +// Populates |hdr| with appripriate values for the modification type. +PosixError PopulateNlmsghdr(LinkAddrModification modification, + struct nlmsghdr* hdr) { + switch (modification) { + case LinkAddrModification::kAdd: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kDelete: + hdr->nlmsg_type = RTM_DELADDR; + hdr->nlmsg_flags = NLM_F_REQUEST; + return NoError(); + } + + return PosixError(EINVAL); +} + +// Adds or removes the specified address from the specified interface. +PosixError LinkModifyLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen, + LinkAddrModification modification) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + PosixError err = PopulateNlmsghdr(modification, &req.hdr); + if (!err.ok()) { + return err; + } + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + } // namespace PosixError DumpLinks( @@ -84,31 +140,14 @@ PosixErrorOr<Link> LoopbackLink() { PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifaddr; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifaddr.ifa_index = index; - req.ifaddr.ifa_family = family; - req.ifaddr.ifa_prefixlen = prefixlen; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFA_LOCAL; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAdd); +} - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kDelete); } PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 149c4a7f6..e5badca70 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -43,6 +43,10 @@ PosixErrorOr<Link> LoopbackLink(); PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkChangeFlags changes interface flags. E.g. IFF_UP. PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); |