summary refs log tree commit diff homepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r-- pkg/p9/client.go 4
-rw-r--r-- pkg/p9/client_file.go 25
-rw-r--r-- pkg/sentry/fsimpl/fuse/dev.go 2
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/filesystem.go 2
-rw-r--r-- pkg/sentry/fsimpl/verity/filesystem.go 61
-rw-r--r-- pkg/sentry/fsimpl/verity/verity.go 85
-rw-r--r-- pkg/sentry/platform/kvm/machine_arm64_unsafe.go 7
-rw-r--r-- pkg/sentry/platform/ring0/aarch64.go 1
-rw-r--r-- pkg/sentry/platform/ring0/entry_arm64.s 74
-rw-r--r-- pkg/sentry/platform/ring0/kernel_arm64.go 4
-rw-r--r-- pkg/sentry/platform/ring0/offsets_arm64.go 1
-rw-r--r-- pkg/sentry/socket/netstack/netstack.go 13
-rw-r--r-- pkg/tcpip/adapters/gonet/BUILD 1
-rw-r--r-- pkg/tcpip/adapters/gonet/gonet_test.go 3
-rw-r--r-- pkg/tcpip/link/sniffer/sniffer.go 58
-rw-r--r-- pkg/tcpip/transport/tcp/connect.go 13
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go 47
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint_state.go 10
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_test.go 14
-rw-r--r-- pkg/tcpip/transport/udp/BUILD 1
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go 29
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go 50
-rw-r--r-- pkg/test/testutil/testutil.go 29
23 files changed, 354 insertions, 180 deletions
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 71e944c30..eadea390a 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -570,6 +570,8 @@ func (c *Client) Version() uint32 {
func (c *Client) Close() {
// unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
// been called (by c.watch()).
- c.socket.Shutdown()
+ if err := c.socket.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err)
+ }
c.closedWg.Wait()
}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 28fe081d6..8b46a2987 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -478,28 +478,23 @@ func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) {
}
// Write implements part of the io.ReadWriter interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.Writer, but is more consistent with gVisor's pattern of
+// returning errors that correspond to Linux errnos. Since short writes without
+// error are common in Linux, returning a nil error is appropriate.
func (r *ReadWriterFile) Write(p []byte) (int, error) {
n, err := r.File.WriteAt(p, r.Offset)
r.Offset += uint64(n)
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return n, err
}
// WriteAt implements the io.WriteAt interface.
+//
+// Note that this may return a short write with a nil error. This violates the
+// contract of io.WriterAt. See comment on Write for justification.
func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) {
- n, err := r.File.WriteAt(p, uint64(offset))
- if err != nil {
- return n, err
- }
- if n < len(p) {
- return n, io.ErrShortWrite
- }
- return n, nil
+ return r.File.WriteAt(p, uint64(offset))
}
// Rename implements File.Rename.
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index 89c3ef079..1bbe6fdb7 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -363,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask {
var ready waiter.EventMask
- if fd.fs.umounted {
+ if fd.fs == nil || fd.fs.umounted {
ready |= waiter.EventErr
return ready & mask
}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e39cd305b..61138a7a4 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -381,6 +381,8 @@ afterTrailingSymlink:
creds := rp.Credentials()
child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
parentDir.insertChildLocked(child, name)
+ child.IncRef()
+ defer child.DecRef(ctx)
unlock()
fd, err := child.open(ctx, rp, &opts, true)
if err != nil {
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index add5dd48e..59fcff498 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -107,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// Dentries which may have a reference count of zero, and which therefore
// should be dropped once traversal is complete, are appended to ds.
//
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
-// !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+// * !rp.Done().
func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -158,15 +160,19 @@ afterSymlink:
return child, nil
}
-// verifyChild verifies the hash of child against the already verified hash of
-// the parent to ensure the child is expected. verifyChild triggers a sentry
-// panic if unexpected modifications to the file system are detected. In
+// verifyChildLocked verifies the hash of child against the already verified
+// hash of the parent to ensure the child is expected. It triggers a
+// sentry panic if unexpected modifications to the file system are detected. In
// noCrashOnVerificationFailure mode it returns a syserror instead.
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * d.dirMu must be locked.
+//
// TODO(b/166474175): Investigate all possible errors returned in this
// function, and make sure we differentiate all errors that indicate unexpected
// modifications to the file system from the ones that are not harmful.
-func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
+func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -268,7 +274,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contain the hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ parent.hashMu.RLock()
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
Out: &buf,
File: &fdReader,
Tree: &fdReader,
@@ -284,21 +291,27 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
- }); err != nil && err != io.EOF {
+ })
+ parent.hashMu.RUnlock()
+ if err != nil && err != io.EOF {
return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
+ child.hashMu.Lock()
if len(child.hash) == 0 {
child.hash = buf.Bytes()
}
+ child.hashMu.Unlock()
return child, nil
}
-// verifyStatAndChildren verifies the stat and children names against the
+// verifyStatAndChildrenLocked verifies the stat and children names against the
// verified hash. The mode/uid/gid and childrenNames of the file is cached
// after verified.
-func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat linux.Statx) error {
+//
+// Preconditions: d.dirMu must be locked.
+func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
// Get the path to the child dentry. This is only used to provide path
@@ -390,6 +403,7 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
}
var buf bytes.Buffer
+ d.hashMu.RLock()
params := &merkletree.VerifyParams{
Out: &buf,
Tree: &fdReader,
@@ -407,6 +421,7 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
Expected: d.hash,
DataAndTreeInSameFile: false,
}
+ d.hashMu.RUnlock()
if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
params.DataAndTreeInSameFile = true
}
@@ -421,7 +436,9 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat
return nil
}
-// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
// If verity is enabled on child, we should check again whether
@@ -470,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// be cached before enabled.
if fs.allowRuntimeEnable {
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
return nil, err
}
}
@@ -486,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
if err != nil {
return nil, err
}
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
return nil, err
}
}
@@ -506,7 +523,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
return child, nil
}
-// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// Preconditions:
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -597,13 +616,13 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
if parent.verityEnabled() {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
if child.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil {
child.destroyLocked(ctx)
return nil, err
}
@@ -617,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// rp.Start().Impl().(*dentry)). It does not check that the returned directory
// is searchable by the provider of rp.
//
-// Preconditions: fs.renameMu must be locked. !rp.Done().
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
@@ -958,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock() // Also releases dirMu on the verification-failure return below.
if d.verityEnabled() {
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return linux.Statx{}, err
}
}
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 87dabe038..46346e54d 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -19,6 +19,18 @@
// The verity file system is read-only, except for one case: when
// allowRuntimeEnable is true, additional Merkle files can be generated using
// the FS_IOC_ENABLE_VERITY ioctl.
+//
+// Lock order:
+//
+// filesystem.renameMu
+// dentry.dirMu
+// fileDescription.mu
+// filesystem.verityMu
+// dentry.hashMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
package verity
import (
@@ -372,12 +384,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
}
- if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil {
+ if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return nil, nil, err
}
}
+ d.hashMu.Lock()
copy(d.hash, iopts.RootHash)
+ d.hashMu.Unlock()
d.vfsd.Init(d)
fs.rootDentry = d
@@ -402,7 +416,8 @@ type dentry struct {
fs *filesystem
// mode, uid, gid and size are the file mode, owner, group, and size of
- // the file in the underlying file system.
+ // the file in the underlying file system. They are set when a dentry
+ // is initialized, and never modified.
mode uint32
uid uint32
gid uint32
@@ -425,18 +440,22 @@ type dentry struct {
// childrenNames stores the name of all children of the dentry. This is
// used by verity to check whether a child is expected. This is only
- // populated by enableVerity.
+ // populated by enableVerity. childrenNames is also protected by dirMu.
childrenNames map[string]struct{}
- // lowerVD is the VirtualDentry in the underlying file system.
+ // lowerVD is the VirtualDentry in the underlying file system. It is
+ // never modified after initialized.
lowerVD vfs.VirtualDentry
// lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree
- // in the underlying file system.
+ // in the underlying file system. It is never modified after
+ // initialized.
lowerMerkleVD vfs.VirtualDentry
- // hash is the calculated hash for the current file or directory.
- hash []byte
+ // hash is the calculated hash for the current file or directory. hash
+ // is protected by hashMu.
+ hashMu sync.RWMutex `state:"nosave"`
+ hash []byte
}
// newDentry creates a new dentry representing the given verity file. The
@@ -519,7 +538,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) {
// destroyLocked destroys the dentry.
//
-// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+// Preconditions:
+// * d.fs.renameMu must be locked for writing.
+// * d.refs == 0.
func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
@@ -599,6 +620,8 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
// mode, it returns true if the target has been enabled with
// ioctl(FS_IOC_ENABLE_VERITY).
func (d *dentry) verityEnabled() bool {
+ d.hashMu.RLock()
+ defer d.hashMu.RUnlock()
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
@@ -678,11 +701,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ fd.d.dirMu.Lock()
+ defer fd.d.dirMu.Unlock() // Also releases dirMu on the verification-failure return below.
if fd.d.verityEnabled() {
- if err := fd.d.fs.verifyStatAndChildren(ctx, fd.d, stat); err != nil {
+ if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil {
return linux.Statx{}, err
}
}
return stat, nil
}
@@ -718,13 +743,15 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
return offset, nil
}
-// generateMerkle generates a Merkle tree file for fd. If fd points to a file
-// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
-// of the generated Merkle tree and the data size is returned. If fd points to
-// a regular file, the data is the content of the file. If fd points to a
-// directory, the data is all hahes of its children, written to the Merkle tree
-// file.
-func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
+// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a
+// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The
+// hash of the generated Merkle tree and the data size is returned. If fd
+// points to a regular file, the data is the content of the file. If fd points
+// to a directory, the data is all hashes of its children, written to the
+// Merkle tree file.
+//
+// Preconditions: fd.d.fs.verityMu must be locked.
+func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) {
fdReader := vfs.FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
@@ -793,11 +820,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
return hash, uint64(params.Size), err
}
-// recordChildren writes the names of fd's children into the corresponding
-// Merkle tree file, and saves the offset/size of the map into xattrs.
+// recordChildrenLocked writes the names of fd's children into the
+// corresponding Merkle tree file, and saves the offset/size of the map into
+// xattrs.
//
-// Preconditions: fd.d.isDir() == true
-func (fd *fileDescription) recordChildren(ctx context.Context) error {
+// Preconditions:
+// * fd.d.fs.verityMu must be locked.
+// * fd.d.isDir() == true.
+func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error {
// Record the children names in the Merkle tree file.
childrenNames, err := json.Marshal(fd.d.childrenNames)
if err != nil {
@@ -847,7 +877,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
- hash, dataSize, err := fd.generateMerkle(ctx)
+ hash, dataSize, err := fd.generateMerkleLocked(ctx)
if err != nil {
return 0, err
}
@@ -888,11 +918,13 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
}
if fd.d.isDir() {
- if err := fd.recordChildren(ctx); err != nil {
+ if err := fd.recordChildrenLocked(ctx); err != nil {
return 0, err
}
}
- fd.d.hash = append(fd.d.hash, hash...)
+ fd.d.hashMu.Lock()
+ fd.d.hash = hash
+ fd.d.hashMu.Unlock()
return 0, nil
}
@@ -904,6 +936,9 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
}
var metadata linux.DigestMetadata
+ fd.d.hashMu.RLock()
+ defer fd.d.hashMu.RUnlock()
+
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
// verity is not enabled for the file. If allowRuntimeEnable is false,
// this is an integrity violation because all files should have verity
@@ -940,11 +975,13 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) {
f := int32(0)
+ fd.d.hashMu.RLock()
// All enabled files should store a hash. This flag is not settable via
// FS_IOC_SETFLAGS.
if len(fd.d.hash) != 0 {
f |= linux.FS_VERITY_FL
}
+ fd.d.hashMu.RUnlock()
t := kernel.TaskFromContext(ctx)
if t == nil {
@@ -1023,6 +1060,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Ctx: ctx,
}
+ fd.d.hashMu.RLock()
n, err := merkletree.Verify(&merkletree.VerifyParams{
Out: dst.Writer(ctx),
File: &dataReader,
@@ -1040,6 +1078,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
+ fd.d.hashMu.RUnlock()
if err != nil {
return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index fd92c3873..3f5be276b 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -263,13 +263,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return usermem.NoAccess, platform.ErrContextInterrupt
case ring0.El0SyncUndef:
return c.fault(int32(syscall.SIGILL), info)
- case ring0.El1SyncUndef:
- *info = arch.SignalInfo{
- Signo: int32(syscall.SIGILL),
- Code: 1, // ILL_ILLOPC (illegal opcode).
- }
- info.SetAddr(switchOpts.Registers.Pc) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
default:
panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
}
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 327d48465..c51df2811 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -90,6 +90,7 @@ const (
El0SyncIa
El0SyncFpsimdAcc
El0SyncSveAcc
+ El0SyncFpsimdExc
El0SyncSys
El0SyncSpPc
El0SyncUndef
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index f77bc72af..5f4b2105a 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -382,12 +382,12 @@ TEXT ·DisableVFP(SB),NOSPLIT,$0
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
-// EXCEPTION_WITH_ERROR is a common exception handler function.
-#define EXCEPTION_WITH_ERROR(user, vector) \
+// EXCEPTION_EL0 is a common el0 exception handler function.
+#define EXCEPTION_EL0(vector) \
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
WORD $0xd538601a; \ //MRS FAR_EL1, R26
MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
- MOVD $user, R3; \
+ MOVD $1, R3; \
MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
MOVD $vector, R3; \
MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
@@ -395,6 +395,12 @@ TEXT ·DisableVFP(SB),NOSPLIT,$0
MOVD R3, CPU_ERROR_CODE(RSV_REG); \
B ·kernelExitToEl1(SB);
+// EXCEPTION_EL1 is a common el1 exception handler function.
+#define EXCEPTION_EL1(vector) \
+ MOVD $vector, R3; \
+ MOVD R3, 8(RSP); \
+ B ·HaltEl1ExceptionAndResume(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -442,6 +448,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
CALL ·kernelSyscall(SB) // Call the trampoline.
B ·kernelExitToEl1(SB) // Resume.
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resumes.
+TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ MOVD vector+0(FP), R3
+ MOVD R3, 16(RSP) // Second argument (vector).
+ CALL ·kernelException(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
// Shutdown stops the guest.
TEXT ·Shutdown(SB),NOSPLIT,$0
// PSCI EVENT.
@@ -604,39 +620,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
B el1_invalid
el1_da:
+ EXCEPTION_EL1(El1SyncDa)
el1_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·HaltAndResume(SB)
-
+ EXCEPTION_EL1(El1SyncIa)
el1_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncSpPc)
el1_undef:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncUndef)
el1_svc:
- MOVD $0, CPU_ERROR_CODE(RSV_REG)
- MOVD $0, CPU_ERROR_TYPE(RSV_REG)
B ·HaltEl1SvcAndResume(SB)
-
el1_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL1(El1SyncDbg)
el1_fpsimd_acc:
VFP_ENABLE
B ·kernelExitToEl1(SB) // Resume.
-
el1_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL1(El1SyncInv)
// El1_irq is the handler for El1_irq.
TEXT ·El1_irq(SB),NOSPLIT,$0
@@ -692,28 +691,21 @@ el0_svc:
el0_da:
el0_ia:
- EXCEPTION_WITH_ERROR(1, PageFault)
-
+ EXCEPTION_EL0(PageFault)
el0_fpsimd_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdAcc)
el0_sve_acc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSveAcc)
el0_fpsimd_exc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncFpsimdExc)
el0_sp_pc:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncSpPc)
el0_undef:
- EXCEPTION_WITH_ERROR(1, El0SyncUndef)
-
+ EXCEPTION_EL0(El0SyncUndef)
el0_dbg:
- B ·Shutdown(SB)
-
+ EXCEPTION_EL0(El0SyncDbg)
el0_invalid:
- B ·Shutdown(SB)
+ EXCEPTION_EL0(El0SyncInv)
TEXT ·El0_irq(SB),NOSPLIT,$0
B ·Shutdown(SB)
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index ead598b24..90a7b8392 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -24,6 +24,10 @@ func HaltAndResume()
//go:nosplit
func HaltEl1SvcAndResume()
+// HaltEl1ExceptionAndResume calls Hooks.KernelException and resumes.
+//go:nosplit
+func HaltEl1ExceptionAndResume()
+
// init initializes architecture-specific state.
func (k *Kernel) init(maxCPUs int) {
}
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 53bc3353c..b5652deb9 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -70,6 +70,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc)
fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 7d0ae15ca..5afe77858 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2686,7 +2686,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
// Always do at least one fetchReadView, even if the number of bytes to
// read is 0.
err = s.fetchReadView()
- if err != nil {
+ if err != nil || len(s.readView) == 0 {
break
}
if dst.NumBytes() == 0 {
@@ -2709,15 +2709,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
}
copied += n
s.readView.TrimFront(n)
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
- }
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
break
}
+ // If we are done reading requested data then stop.
+ if dst.NumBytes() == 0 {
+ break
+ }
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
}
// If we managed to copy something, we must deliver it.
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index a984f1712..d69b1e081 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -26,6 +26,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index b196324c7..4b6bf4bba 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -65,7 +66,7 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
})
- if err := s.CreateNIC(NICID, loopback.New()); err != nil {
+ if err := s.CreateNIC(NICID, sniffer.New(loopback.New())); err != nil {
return nil, err
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b3e8c4b92..178e658df 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -53,16 +53,35 @@ type endpoint struct {
nested.Endpoint
writer io.Writer
maxPCAPLen uint32
+ logPrefix string
}
var _ stack.GSOEndpoint = (*endpoint)(nil)
var _ stack.LinkEndpoint = (*endpoint)(nil)
var _ stack.NetworkDispatcher = (*endpoint)(nil)
+type direction int // Whether a logged packet was sent or received.
+
+const (
+ directionSend direction = iota // Outbound; logged as "send".
+ directionRecv // Inbound; logged as "recv".
+)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- sniffer := &endpoint{}
+ return NewWithPrefix(lower, "")
+}
+
+// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets prefixed with logPrefix as they traverse
+// the endpoint.
+//
+// logPrefix is prepended to the log line without any separators.
+// E.g. logPrefix = "NIC:en0/" will produce log lines like
+// "NIC:en0/send udp [...]".
+func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint {
+ sniffer := &endpoint{logPrefix: logPrefix}
sniffer.Endpoint.Init(lower, sniffer)
return sniffer
}
@@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, pkt)
+ e.dumpPacket(directionRecv, nil, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -178,20 +197,20 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// forwards the request to the lower endpoint.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.dumpPacket(directionSend, nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
}))
return e.Endpoint.WriteRawPacket(vv)
}
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -201,6 +220,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var fragmentOffset uint16
var moreFragments bool
+ var directionPrefix string
+ switch dir {
+ case directionSend:
+ directionPrefix = "send"
+ case directionRecv:
+ directionPrefix = "recv"
+ default:
+ panic(fmt.Sprintf("unrecognized direction: %d", dir))
+ }
+
// Clone the packet buffer to not modify the original.
//
// We don't clone the original packet buffer so that the new packet buffer
@@ -248,15 +277,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
- "%s arp %s (%s) -> %s (%s) valid:%t",
+ "%s%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
+ directionPrefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
- log.Infof("%s unknown network protocol", prefix)
+ log.Infof("%s%s unknown network protocol", prefix, directionPrefix)
return
}
@@ -300,7 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -335,7 +365,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -391,7 +421,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto)
return
}
@@ -399,5 +429,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index ac6d879a7..6661e8915 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -496,7 +496,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
}
@@ -575,7 +575,6 @@ func (h *handshake) complete() *tcpip.Error {
return err
}
defer timer.stop()
-
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@@ -631,9 +630,8 @@ func (h *handshake) complete() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
-
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@@ -1002,7 +1000,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@@ -1141,7 +1139,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.HardError = tcpip.ErrAborted
+ e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1353,7 +1351,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be held upon entering this section.
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@@ -1383,7 +1380,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
e.workerCleanup = true
// Lock released below.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4f4f4c65e..a2161e49d 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -315,11 +315,6 @@ func (*Stats) IsEndpointStats() {}
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
-
- // HardError is meaningful only when state is stateError. It stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. HardError is protected by endpoint mu.
- HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -386,6 +381,11 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
+ // hardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by endpoint mu.
+ hardError *tcpip.Error `state:".(string)"`
+
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -1283,7 +1283,15 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) LastError() *tcpip.Error {
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) hardErrorLocked() *tcpip.Error {
+ err := e.hardError
+ e.hardError = nil
+ return err
+}
+
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1291,6 +1299,15 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+func (e *endpoint) LastError() *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return e.lastErrorLocked()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1312,9 +1329,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.HardError
if s == StateError {
- return buffer.View{}, tcpip.ControlMessages{}, he
+ return buffer.View{}, tcpip.ControlMessages{}, e.hardErrorLocked()
}
e.stats.ReadErrors.NotConnected.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
@@ -1370,9 +1386,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
- return 0, e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return 0, err
+ }
+ return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@@ -1486,7 +1506,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.HardError
+ return 0, tcpip.ControlMessages{}, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -2243,7 +2263,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@@ -2417,7 +2440,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
// Call cleanupLocked to free up any reservations.
e.cleanupLocked()
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index bb901c0f8..ba67176b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -321,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
-func (e *EndpointInfo) saveHardError() string {
- if e.HardError == nil {
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
return ""
}
- return e.HardError.String()
+ return e.hardError.String()
}
// loadHardError is invoked by stateify.
-func (e *EndpointInfo) loadHardError(s string) {
+func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
- e.HardError = tcpip.StringToError(s)
+ e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9f0fb41e3..c366a4cbc 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.LastError(); err != tcpip.ErrAborted {
- t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
- }
// Call Connect again to retrieve the handshake failure status
// and stats updates.
@@ -3198,6 +3195,11 @@ loop:
case tcpip.ErrWouldBlock:
select {
case <-ch:
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@@ -3207,14 +3209,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
- }
+
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
-
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c78549424..7ebae63d8 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 57976d4e3..835dcc54e 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -429,7 +429,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
to := opts.To
e.mu.RLock()
- defer e.mu.RUnlock()
+ lockReleased := false
+ defer func() {
+ if lockReleased {
+ return
+ }
+ e.mu.RUnlock()
+ }()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
@@ -475,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.state != StateConnected {
err = tcpip.ErrInvalidEndpointState
}
- return
+ return ch, err
}
} else {
// Reject destination address if it goes through a different
@@ -541,7 +547,24 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ localPort := e.ID.LocalPort
+ sendTOS := e.sendTOS
+ owner := e.owner
+ noChecksum := e.noChecksum
+ lockReleased = true
+ e.mu.RUnlock()
+
+ // Do not hold the lock while sending: loopback is synchronous, so if the UDP
+ // datagram ends up generating an ICMP response, handling that response
+ // acquires this endpoint's mutex again with e.mu.RLock() in
+ // endpoint.HandleControlPacket. That recursive read lock deadlocks if
+ // another caller is concurrently trying to acquire e.mu in exclusive mode
+ // with e.mu.Lock(), because a pending e.mu.Lock() blocks any new read locks
+ // to ensure the exclusive lock can eventually be acquired.
+ //
+ // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
+ // locking is prohibited.
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 764ad0857..492e277a8 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +55,7 @@ const (
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
+ invalidPort = 8192
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
@@ -295,7 +297,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
})
}
@@ -972,7 +975,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
// provided.
func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
+ return testWriteAndVerifyInternal(c, flow, true, checkers...)
}
// testWriteWithoutDestination sends a packet of the given test flow from the
@@ -981,10 +984,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker
// checker functions provided.
func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
+ return testWriteAndVerifyInternal(c, flow, false, checkers...)
}
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
@@ -1006,6 +1009,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
c.checkEndpointWriteStats(1, epstats, err)
+ return payload
+}
+
+func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ payload := testWriteNoVerify(c, flow, setDest)
// Receive the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -1150,6 +1159,39 @@ func TestV4WriteOnConnected(t *testing.T) {
testWriteWithoutDestination(c, unicastV4)
}
+func TestWriteOnConnectedInvalidPort(t *testing.T) {
+ protocols := map[string]tcpip.NetworkProtocolNumber{
+ "ipv4": ipv4.ProtocolNumber,
+ "ipv6": ipv6.ProtocolNumber,
+ }
+ for name, pn := range protocols {
+ t.Run(name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(pn)
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ }
+ if got, want := n, int64(len(payload)); got != want {
+ c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
+ }
+
+ if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
+ })
+ }
+}
+
// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
// that is bound to a V4 multicast address.
func TestWriteOnBoundToV4Multicast(t *testing.T) {
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 49ab87c58..976331230 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -36,7 +36,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync/atomic"
"syscall"
"testing"
"time"
@@ -417,33 +416,35 @@ func StartReaper() func() {
// WaitUntilRead reads from the given reader until the wanted string is found
// or until timeout.
-func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error {
+func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error {
sc := bufio.NewScanner(r)
- if split != nil {
- sc.Split(split)
- }
// done must be accessed atomically. A value greater than 0 indicates
// that the read loop can exit.
- var done uint32
- doneCh := make(chan struct{})
+ doneCh := make(chan bool)
+ defer close(doneCh)
go func() {
for sc.Scan() {
t := sc.Text()
if strings.Contains(t, want) {
- atomic.StoreUint32(&done, 1)
- close(doneCh)
- break
+ doneCh <- true
+ return
}
- if atomic.LoadUint32(&done) > 0 {
- break
+ select {
+ case <-doneCh:
+ return
+ default:
}
}
+ doneCh <- false
}()
+
select {
case <-time.After(timeout):
- atomic.StoreUint32(&done, 1)
return fmt.Errorf("timeout waiting to read %q", want)
- case <-doneCh:
+ case res := <-doneCh:
+ if !res {
+ return fmt.Errorf("reader closed while waiting to read %q", want)
+ }
return nil
}
}