summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/merkletree/merkletree.go25
-rw-r--r--pkg/merkletree/merkletree_test.go42
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md9
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go222
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go77
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go3
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s48
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go5
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go9
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s14
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go45
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go10
-rw-r--r--pkg/tcpip/link/rawfile/BUILD13
-rw-r--r--pkg/tcpip/link/rawfile/errors.go8
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go53
-rw-r--r--pkg/tcpip/stack/nic.go44
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
-rw-r--r--pkg/tcpip/tcpip.go13
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go61
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go120
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go68
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go32
25 files changed, 796 insertions, 181 deletions
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 36832ec86..4b4f9bd52 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -225,9 +225,9 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
// Verify will modify the cursor for data, but always restores it to its
// original position upon exit. The cursor for tree is modified and not
// restored.
-func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) error {
+func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) {
if readSize <= 0 {
- return fmt.Errorf("Unexpected read size: %d", readSize)
+ return 0, fmt.Errorf("Unexpected read size: %d", readSize)
}
layout := InitLayout(int64(dataSize), dataAndTreeInSameFile)
@@ -240,29 +240,30 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
// finishes.
origOffset, err := data.Seek(0, io.SeekCurrent)
if err != nil {
- return fmt.Errorf("Find current data offset failed: %v", err)
+ return 0, fmt.Errorf("Find current data offset failed: %v", err)
}
defer data.Seek(origOffset, io.SeekStart)
// Move to the first block that contains target data.
if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
- return fmt.Errorf("Seek to datablock start failed: %v", err)
+ return 0, fmt.Errorf("Seek to datablock start failed: %v", err)
}
buf := make([]byte, layout.blockSize)
var readErr error
- bytesRead := 0
+ total := int64(0)
for i := firstDataBlock; i <= lastDataBlock; i++ {
// Read a block that includes all or part of target range in
// input data.
- bytesRead, readErr = data.Read(buf)
+ bytesRead, err := data.Read(buf)
+ readErr = err
// If at the end of input data and all previous blocks are
// verified, return the verified input data and EOF.
if readErr == io.EOF && bytesRead == 0 {
break
}
if readErr != nil && readErr != io.EOF {
- return fmt.Errorf("Read from data failed: %v", err)
+ return 0, fmt.Errorf("Read from data failed: %v", err)
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
@@ -274,7 +275,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
}
}
if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
- return err
+ return 0, err
}
// startOff is the beginning of the read range within the
// current data block. Note that for all blocks other than the
@@ -298,10 +299,14 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
if endOff > int64(bytesRead) {
endOff = int64(bytesRead)
}
- w.Write(buf[startOff:endOff])
+ n, err := w.Write(buf[startOff:endOff])
+ if err != nil {
+ return total, err
+ }
+ total += int64(n)
}
- return readErr
+ return total, readErr
}
// verifyBlock verifies a block against tree. index is the number of block in
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index ad50ba5f6..daaca759a 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -67,17 +67,17 @@ func TestLayout(t *testing.T) {
t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile)
if l.blockSize != int64(usermem.PageSize) {
- t.Errorf("got blockSize %d, want %d", l.blockSize, usermem.PageSize)
+ t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
}
if l.digestSize != sha256DigestSize {
- t.Errorf("got digestSize %d, want %d", l.digestSize, sha256DigestSize)
+ t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
}
if l.numLevels() != len(tc.expectedLevelOffset) {
- t.Errorf("got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset))
+ t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset))
}
for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ {
if l.levelOffset[i] != tc.expectedLevelOffset[i] {
- t.Errorf("got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i])
+ t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i])
}
}
})
@@ -169,11 +169,11 @@ func TestGenerate(t *testing.T) {
}, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
}
if err != nil {
- t.Fatalf("got err: %v, want nil", err)
+ t.Fatalf("Got err: %v, want nil", err)
}
if !bytes.Equal(root, tc.expectedRoot) {
- t.Errorf("got root: %v, want %v", root, tc.expectedRoot)
+ t.Errorf("Got root: %v, want %v", root, tc.expectedRoot)
}
}
})
@@ -334,14 +334,21 @@ func TestVerify(t *testing.T) {
var buf bytes.Buffer
data[tc.modifyByte] ^= 1
if tc.shouldSucceed {
- if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err != nil && err != io.EOF {
+ n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile)
+ if err != nil && err != io.EOF {
t.Errorf("Verification failed when expected to succeed: %v", err)
}
- if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
- t.Errorf("Incorrect output from Verify")
+ if n != tc.verifySize {
+ t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ }
+ if int64(buf.Len()) != tc.verifySize {
+ t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ }
+ if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
}
} else {
- if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
t.Errorf("Verification succeeded when expected to fail")
}
}
@@ -382,14 +389,21 @@ func TestVerifyRandom(t *testing.T) {
var buf bytes.Buffer
// Checks that the random portion of data from the original data is
// verified successfully.
- if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err != nil && err != io.EOF {
+ n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile)
+ if err != nil && err != io.EOF {
t.Errorf("Verification failed for correct data: %v", err)
}
if size > dataSize-start {
size = dataSize - start
}
- if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) {
- t.Errorf("Incorrect output from Verify")
+ if n != size {
+ t.Errorf("Got Verify output size %d, want %d", n, size)
+ }
+ if int64(buf.Len()) != size {
+ t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
+ }
+ if !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
}
buf.Reset()
@@ -397,7 +411,7 @@ func TestVerifyRandom(t *testing.T) {
randBytePos := rand.Int63n(size)
data[start+randBytePos] ^= 1
- if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
t.Errorf("Verification succeeded for modified data")
}
}
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
index b102c9b34..eccb1fb2f 100644
--- a/pkg/sentry/fs/g3doc/fuse.md
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -119,8 +119,8 @@ ops can be implemented in parallel.
`connection.initializedChan`
-- a channel that the requests issued before connection initialization
- blocks on.
+- a channel that the requests issued before connection initialization blocks
+ on.
`fd.queue`
@@ -128,9 +128,8 @@ ops can be implemented in parallel.
`fd.completions`
-- a map of the requests that have been prepared but
- not yet received a response, including the ones on the
- `fd.queue`.
+- a map of the requests that have been prepared but not yet received a
+ response, including the ones on the `fd.queue`.
`fd.waitQueue`
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 2cf0a38c9..f86a6a0b2 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io"
"strconv"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -192,10 +193,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -204,10 +202,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- if noCrashOnVerificationFailure {
- return nil, syserror.EINVAL
- }
- panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's root hash.
@@ -221,10 +216,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -242,10 +234,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -255,10 +244,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- if noCrashOnVerificationFailure {
- return nil, syserror.EINVAL
- }
- panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -270,11 +256,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contain the root hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
- if noCrashOnVerificationFailure {
- return nil, syserror.EIO
- }
- panic(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
+ return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child root hash when it's verified the first time.
@@ -370,10 +353,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// corresponding Merkle tree is found. This indicates an
// unexpected modification to the file system that
// removed/renamed the child.
- if noCrashOnVerificationFailure {
- return nil, childErr
- }
- panic(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
} else if childErr == nil && childMerkleErr == syserror.ENOENT {
// If in allowRuntimeEnable mode, and the Merkle tree file is
// not created yet, we create an empty Merkle tree file, so that
@@ -409,10 +389,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// If runtime enable is not allowed. This indicates an
// unexpected modification to the file system that
// removed/renamed the Merkle tree file.
- if noCrashOnVerificationFailure {
- return nil, childMerkleErr
- }
- panic(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
}
} else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT {
// Both the child and the corresponding Merkle tree are missing.
@@ -421,7 +398,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// TODO(b/167752508): Investigate possible ways to differentiate
// cases that both files are deleted from cases that they never
// exist in the file system.
- panic(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
}
mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
@@ -580,8 +557,181 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- //TODO(b/159261227): Implement OpenAt.
- return nil, nil
+ // Verity fs is read-only.
+ if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 {
+ return nil, syserror.EROFS
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if rp.Done() {
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+
+ // Open existing child or follow symlink.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds)
+ parent.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Users should not open the Merkle tree files. Those are for verity fs
+ // use only.
+ if strings.Contains(d.name, merklePrefix) {
+ return nil, syserror.EPERM
+ }
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+
+ // Verity fs is read-only.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EROFS
+ }
+
+ // Get the path to the target file. This is only used to provide path
+ // information in failure case.
+ path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD)
+ if err != nil {
+ return nil, err
+ }
+
+ // Open the file in the underlying file system.
+ lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ }, opts)
+
+ // The file should exist, as we succeeded in finding its dentry. If it's
+ // missing, it indicates an unexpected modification to the file system.
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path))
+ }
+ return nil, err
+ }
+
+ // lowerFD needs to be cleaned up if any error occurs. IncRef will be
+ // called if a verity FD is successfully created.
+ defer lowerFD.DecRef(ctx)
+
+ // Open the Merkle tree file corresponding to the current file/directory
+ // to be used later for verifying Read/Walk.
+ merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+
+ // The Merkle tree file should exist, as we succeeded in finding its
+ // dentry. If it's missing, it indicates an unexpected modification to
+ // the file system.
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ }
+ return nil, err
+ }
+
+ // merkleReader needs to be cleaned up if any error occurs. IncRef will
+ // be called if a verity FD is successfully created.
+ defer merkleReader.DecRef(ctx)
+
+ lowerFlags := lowerFD.StatusFlags()
+ lowerFDOpts := lowerFD.Options()
+ var merkleWriter *vfs.FileDescription
+ var parentMerkleWriter *vfs.FileDescription
+
+ // Only open the Merkle tree files for write if in allowRuntimeEnable
+ // mode.
+ if d.fs.allowRuntimeEnable {
+ merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ }
+ return nil, err
+ }
+ // merkleWriter is cleaned up if any error occurs. IncRef will
+ // be called if a verity FD is created successfully.
+ defer merkleWriter.DecRef(ctx)
+
+ parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.parent.lowerMerkleVD,
+ Start: d.parent.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ }
+ return nil, err
+ }
+ // parentMerkleWriter is cleaned up if any error occurs. IncRef
+ // will be called if a verity FD is created successfully.
+ defer parentMerkleWriter.DecRef(ctx)
+ }
+
+ fd := &fileDescription{
+ d: d,
+ lowerFD: lowerFD,
+ merkleReader: merkleReader,
+ merkleWriter: merkleWriter,
+ parentMerkleWriter: parentMerkleWriter,
+ isDir: d.isDir(),
+ }
+
+ if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil {
+ return nil, err
+ }
+ lowerFD.IncRef()
+ merkleReader.IncRef()
+ if merkleWriter != nil {
+ merkleWriter.IncRef()
+ }
+ if parentMerkleWriter != nil {
+ parentMerkleWriter.IncRef()
+ }
+ return &fd.vfsfd, err
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 249cc1341..3e0bcd02b 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -22,6 +22,7 @@
package verity
import (
+ "fmt"
"strconv"
"sync/atomic"
@@ -29,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/marshal/primitive"
-
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
@@ -135,6 +135,16 @@ func (FilesystemType) Name() string {
return Name
}
+// alertIntegrityViolation alerts a violation of integrity, which usually means
+// unexpected modification to the file system is detected. In
+// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
+func alertIntegrityViolation(err error, msg string) error {
+ if noCrashOnVerificationFailure {
+ return err
+ }
+ panic(msg)
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
@@ -204,15 +214,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
} else if err != nil {
- // Failed to get dentry for the root Merkle file. This indicates
- // an attack that removed/renamed the root Merkle file, or it's
- // never generated.
- if noCrashOnVerificationFailure {
- fs.vfsfs.DecRef(ctx)
- d.DecRef(ctx)
- return nil, nil, err
- }
- panic("Failed to find root Merkle file")
+ // Failed to get dentry for the root Merkle file. This
+ // indicates an unexpected modification that removed/renamed
+ // the root Merkle file, or it's never generated.
+ fs.vfsfs.DecRef(ctx)
+ d.DecRef(ctx)
+ return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file")
}
d.lowerMerkleVD = lowerMerkleVD
@@ -550,7 +557,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
defer verityMu.Unlock()
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil {
- panic("Unexpected verity fd: missing expected underlying fds")
+ return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
}
rootHash, dataSize, err := fd.generateMerkle(ctx)
@@ -618,6 +625,54 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
}
}
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // No need to verify if the file is not enabled yet in
+ // allowRuntimeEnable mode.
+ if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 {
+ return fd.lowerFD.PRead(ctx, dst, offset, opts)
+ }
+
+ // dataSize is the size of the whole file.
+ dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfInt32,
+ })
+
+ // The Merkle tree file for the child should have been created and
+ // contains the expected xattrs. If the xattr does not exist, it
+ // indicates unexpected modifications to the file system.
+ if err == syserror.ENODATA {
+ return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // The dataSize xattr should be an integer. If it's not, it indicates
+ // unexpected modifications to the file system.
+ size, err := strconv.Atoi(dataSize)
+ if err != nil {
+ return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ }
+
+ dataReader := vfs.FileReadWriteSeeker{
+ FD: fd.lowerFD,
+ Ctx: ctx,
+ }
+
+ merkleReader := vfs.FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+
+ n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */)
+ if err != nil {
+ return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err))
+ }
+ return n, err
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 905712076..537419657 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error {
}
// tcr_el1
- data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1
reg.id = _KVM_ARM64_REGS_TCR_EL1
if err := c.setOneRegister(&reg); err != nil {
return err
@@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error {
c.SetTtbr0Kvm(uintptr(data))
// ttbr1_el1
- data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1)
reg.id = _KVM_ARM64_REGS_TTBR1_EL1
if err := c.setOneRegister(&reg); err != nil {
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index 0e2ab716c..508236e46 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -77,6 +77,9 @@ type CPUArchState struct {
// lazyVFP is the value of cpacr_el1.
lazyVFP uintptr
+
+ // appASID is the asid value of guest application.
+ appASID uintptr
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 1e477cc49..5f63cbd45 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -302,17 +302,23 @@
// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
#define SWITCH_TO_APP_PAGETABLE(from) \
- MOVD CPU_TTBR0_APP(from), RSV_REG; \
- WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ MRS TTBR1_EL1, R0; \
+ MOVD CPU_APP_ASID(from), R1; \
+ BFI $48, R1, $16, R0; \
+ MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set)
ISB $15; \
- DSB $15;
+ MOVD CPU_TTBR0_APP(from), RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1;
// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
#define SWITCH_TO_KVM_PAGETABLE(from) \
- MOVD CPU_TTBR0_KVM(from), RSV_REG; \
- WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ MRS TTBR1_EL1, R0; \
+ MOVD $1, R1; \
+ BFI $48, R1, $16, R0; \
+ MSR R0, TTBR1_EL1; \
ISB $15; \
- DSB $15;
+ MOVD CPU_TTBR0_KVM(from), RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1;
#define VFP_ENABLE \
MOVD $FPEN_ENABLE, R0; \
@@ -328,23 +334,20 @@
#define KERNEL_ENTRY_FROM_EL0 \
SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
- WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable.
- SWITCH_TO_KVM_PAGETABLE(RSV_REG); \
- WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
- MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer.
- REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context.
+ WORD $0xd538d092; \ // MRS TPIDR_EL1, R18
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer.
+ REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context.
MOVD RSV_REG_APP, R20; \
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \
ADD $16, RSP, RSP; \
MOVD RSV_REG, PTRACE_R18(R20); \
MOVD RSV_REG_APP, PTRACE_R9(R20); \
- MOVD R20, RSV_REG_APP; \
WORD $0xd5384003; \ // MRS SPSR_EL1, R3
- MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \
+ MOVD R3, PTRACE_PSTATE(R20); \
MRS ELR_EL1, R3; \
- MOVD R3, PTRACE_PC(RSV_REG_APP); \
+ MOVD R3, PTRACE_PC(R20); \
WORD $0xd5384103; \ // MRS SP_EL0, R3
- MOVD R3, PTRACE_SP(RSV_REG_APP);
+ MOVD R3, PTRACE_SP(R20);
// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
#define KERNEL_ENTRY_FROM_EL1 \
@@ -359,6 +362,13 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// storeAppASID writes the application's asid value.
+TEXT ·storeAppASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ MRS TPIDR_EL1, RSV_REG
+ MOVD R1, CPU_APP_ASID(RSV_REG)
+ RET
+
// Halt halts execution.
TEXT ·Halt(SB),NOSPLIT,$0
// Clear bluepill.
@@ -416,7 +426,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8
MOVD R8, ret+0(FP)
RET
-#define STACK_FRAME_SIZE 16
+#define STACK_FRAME_SIZE 32
// kernelExitToEl0 is the entrypoint for application in guest_el0.
// Prepare the vcpu environment for container application.
@@ -460,15 +470,16 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
SUB $STACK_FRAME_SIZE, RSP, RSP
STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+ STP (R0, R1), 16*1(RSP)
WORD $0xd538d092 //MRS TPIDR_EL1, R18
SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ LDP 16*1(RSP), (R0, R1)
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
ADD $STACK_FRAME_SIZE, RSP, RSP
- ISB $15
ERET()
// kernelExitToEl1 is the entrypoint for sentry in guest_el1.
@@ -484,6 +495,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ MRS TPIDR_EL1, RSV_REG
+
REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index d0afa1aaa..14774c5db 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -53,6 +53,11 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ storeAppASID(uintptr(switchOpts.UserASID))
+ if switchOpts.Flush {
+ FlushTlbAll()
+ }
+
regs := switchOpts.Registers
regs.Pstate &= ^uint64(PsrFlagsClear)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 00e52c8af..2f1abcb0f 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -16,6 +16,15 @@
package ring0
+// storeAppASID writes the application's asid value.
+func storeAppASID(asid uintptr)
+
+// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU.
+func LocalFlushTlbAll()
+
+// FlushTlbAll flush all tlb.
+func FlushTlbAll()
+
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 86bfbe46f..8aabf7d0e 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -15,6 +15,20 @@
#include "funcdata.h"
#include "textflag.h"
+TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
+ DSB $6 // dsb(nshst)
+ WORD $0xd508871f // __tlbi(vmalle1)
+ DSB $7 // dsb(nsh)
+ ISB $15
+ RET
+
+TEXT ·FlushTlbAll(SB),NOSPLIT,$0
+ DSB $10 // dsb(ishst)
+ WORD $0xd508831f // __tlbi(vmalle1is)
+ DSB $11 // dsb(ish)
+ ISB $15
+ RET
+
TEXT ·GetTLS(SB),NOSPLIT,$0-8
MRS TPIDR_EL0, R1
MOVD R1, ret+0(FP)
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index f3de962f0..1d86b4bcf 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -41,6 +41,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 2e568bc3d..b462924af 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -482,8 +482,35 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(context.Context) {
+func (s *socketOpsCommon) Release(ctx context.Context) {
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventHUp|waiter.EventErr)
+ defer s.EventUnregister(&e)
+
s.Endpoint.Close()
+
+ // SO_LINGER option is valid only for TCP. For other socket types
+ // return after endpoint close.
+ if family, skType, _ := s.Type(); skType != linux.SOCK_STREAM || (family != linux.AF_INET && family != linux.AF_INET6) {
+ return
+ }
+
+ var v tcpip.LingerOption
+ if err := s.Endpoint.GetSockOpt(&v); err != nil {
+ return
+ }
+
+ // The case for zero timeout is handled in tcp endpoint close function.
+ // Close is blocked until either:
+ // 1. The endpoint state is not in any of the states: FIN-WAIT1,
+ // CLOSING and LAST_ACK.
+ // 2. Timeout is reached.
+ if v.Enabled && v.Timeout != 0 {
+ t := kernel.TaskFromContext(ctx)
+ start := t.Kernel().MonotonicClock().Now()
+ deadline := start.Add(v.Timeout)
+ t.BlockWithDeadline(ch, true, deadline)
+ }
}
// Read implements fs.FileOperations.Read.
@@ -1155,7 +1182,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- linger := linux.Linger{}
+ var v tcpip.LingerOption
+ var linger linux.Linger
+ if err := ep.GetSockOpt(&v); err != nil {
+ return &linger, nil
+ }
+
+ if v.Enabled {
+ linger.OnOff = 1
+ }
+ linger.Linger = int32(v.Timeout.Seconds())
return &linger, nil
case linux.SO_SNDTIMEO:
@@ -1884,7 +1920,10 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return nil
+ return syserr.TranslateNetstackError(
+ ep.SetSockOpt(&tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger)}))
case linux.SO_DETACH_FILTER:
// optval is ignored.
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index cbbdd000f..08504560c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -945,8 +945,14 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case *tcpip.LingerOption:
+ return nil
+
+ default:
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// LastError implements Endpoint.LastError.
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..6c410c5a6 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,14 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ srcs = [
+ "errors_test.go",
+ ],
+ library = "rawfile",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index a0a873c84..604868fd8 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error
// *tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+// tcpip.ErrInvalidEndpointState (EINVAL).
func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if err := translations[e]; err != nil {
- return err
+ if e > 0 && e < syscall.Errno(len(translations)) {
+ if err := translations[e]; err != nil {
+ return err
+ }
}
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
new file mode 100644
index 000000000..e4cdc66bd
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestTranslateErrno(t *testing.T) {
+ for _, test := range []struct {
+ errno syscall.Errno
+ translated *tcpip.Error
+ }{
+ {
+ errno: syscall.Errno(0),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(maxErrno),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(514),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.EEXIST,
+ translated: tcpip.ErrDuplicateAddress,
+ },
+ } {
+ got := TranslateErrno(test.errno)
+ if got != test.translated {
+ t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 863ef6bee..1f1a1426b 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -665,33 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
- // Check if address is a broadcast address for the endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
n.mu.RUnlock()
return ref
}
}
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the IPv4 address is found in the NIC's subnets and the NIC
- // is a loopback interface.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber {
- for _, r := range n.mu.endpoints {
- addr := r.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.Contains(address) {
- createTempEP = true
- break
- }
- }
- }
n.mu.RUnlock()
- if !createTempEP {
+ if !spoofingOrPromiscuous {
return nil
}
@@ -704,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-// getRefForBroadcastLocked returns an endpoint where address is the IPv4
-// broadcast address for the endpoint's network.
+// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the
+// broadcast address for the endpoint's network or an address in the endpoint's
+// subnet if the NIC is a loopback interface. This matches linux behaviour.
//
-// n.mu MUST be read locked.
-func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+// n.mu MUST be read or write locked.
+func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint {
for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses.
+ // Only IPv4 has a notion of broadcast addresses or considers the loopback
+ // interface bound to an address's whole subnet (on linux).
if ref.protocol != header.IPv4ProtocolNumber {
continue
}
- addr := ref.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ subnet := ref.addrWithPrefix().Subnet()
+ if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() {
return ref
}
}
@@ -745,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
- // Check if address is a broadcast address for an endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
return ref
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 5e34e27ba..b2ddb24ec 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1194,6 +1194,19 @@ const (
TCPTimeWaitReuseLoopbackOnly
)
+// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
+// duration for which a socket lingers before returning from Close.
+//
+// +stateify savable
+type LingerOption struct {
+ Enabled bool
+ Timeout time.Duration
+}
+
+func (*LingerOption) isGettableSocketOption() {}
+
+func (*LingerOption) isSettableSocketOption() {}
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 1b18023c5..fecbe7ba7 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -187,3 +187,64 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
})
}
}
+
+// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
+// in a loopback interface's associated subnet is bound to the permanently bound
+// address.
+func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
+ const nicID = 1
+
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ addrBytes := []byte(ipv4Addr.Address)
+ addrBytes[len(addrBytes)-1]++
+ otherAddr := tcpip.Address(addrBytes)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ params := stack.NetworkHeaderParams{
+ Protocol: 111,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err)
+ }
+
+ // Removing the address should make the endpoint invalid.
+ if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
+ }
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 52c27e045..659acbc7a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
})
}
}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6d5046a3d..faea7f2bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -654,6 +654,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -1007,6 +1010,26 @@ func (e *endpoint) Close() {
return
}
+ if e.linger.Enabled && e.linger.Timeout == 0 {
+ s := e.EndpointState()
+ isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
+ if isResetState {
+ // Close the endpoint without doing full shutdown and
+ // send a RST.
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.closeNoShutdownLocked()
+
+ // Wake up worker to close the endpoint.
+ switch s {
+ case StateSynRecv:
+ e.notifyProtocolGoroutine(notifyClose)
+ default:
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ return
+ }
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -1807,6 +1830,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ e.LockUser()
+ e.linger = *v
+ e.UnlockUser()
+
default:
return nil
}
@@ -2031,6 +2059,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
+ case *tcpip.LingerOption:
+ e.LockUser()
+ *o = e.linger
+ e.UnlockUser()
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0d13e1efd..b1e5f1b24 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index baf7df197..85e8c1c75 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{