summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go22
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go7
-rw-r--r--pkg/sentry/fsimpl/host/host.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go3
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go6
-rw-r--r--pkg/sentry/fsimpl/proc/task.go9
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go24
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go8
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s18
-rw-r--r--pkg/sentry/vfs/anonfs.go2
-rw-r--r--pkg/sentry/vfs/permissions.go23
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go239
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go222
15 files changed, 320 insertions, 276 deletions
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index dbe58acbe..055ac1d7c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -287,6 +287,11 @@ func (m FileMode) ExtraBits() FileMode {
return m &^ (PermissionsMask | FileTypeMask)
}
+// IsDir returns true if file type represents a directory.
+func (m FileMode) IsDir() bool {
+ return m.FileType() == S_IFDIR
+}
+
// String returns a string representation of m.
func (m FileMode) String() string {
var s []string
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 6962083f5..a39a37318 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -186,7 +186,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
}
func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID())
+ return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID())
}
// statTo writes the statx fields to the output parameter.
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 26b492185..1e43df9ec 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -119,7 +119,7 @@ func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -314,7 +314,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if parent.isDeleted() {
@@ -378,7 +378,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -512,7 +512,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
if err != nil {
return err
}
- return d.checkPermissions(creds, ats, d.isDir())
+ return d.checkPermissions(creds, ats)
}
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
@@ -528,7 +528,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -624,7 +624,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Determine whether or not we need to create a file.
@@ -661,7 +661,7 @@ afterTrailingSymlink:
// Preconditions: fs.renameMu must be locked.
func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
- if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
mnt := rp.Mount()
@@ -722,7 +722,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
- if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if d.isDeleted() {
@@ -884,7 +884,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return err
}
}
- if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
vfsObj := rp.VirtualFilesystem()
@@ -904,7 +904,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return syserror.EINVAL
}
if oldParent != newParent {
- if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -915,7 +915,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
- if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
newParent.dirMu.Lock()
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 13928ce36..cf276a417 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -721,7 +721,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -843,8 +844,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return nil
}
-func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
// IncRef implements vfs.DentryImpl.IncRef.
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1f735628f..a54985ef5 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -167,8 +167,8 @@ func fileFlagsFromHostFD(fd int) (int, error) {
}
// CheckPermissions implements kernfs.Inode.
-func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, atx, false /* isDir */, uint16(i.mode), i.uid, i.gid)
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, i.mode, i.uid, i.gid)
}
// Mode implements kernfs.Inode.
@@ -306,7 +306,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &s, uint16(i.Mode().Permissions()), i.uid, i.gid); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &s, i.Mode(), i.uid, i.gid); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 75c4bab1a..bfa786c88 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -206,8 +206,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- fs := fd.filesystem()
creds := auth.CredentialsFromContext(ctx)
inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(ctx, fs, creds, opts)
+ return inode.SetStat(ctx, fd.filesystem(), creds, opts)
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index c612dcf07..5c84b10c9 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -241,7 +241,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, uint16(a.Mode().Permissions()), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
@@ -273,12 +273,10 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
// CheckPermissions implements Inode.CheckPermissions.
func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
- mode := a.Mode()
return vfs.GenericCheckPermissions(
creds,
ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
+ a.Mode(),
auth.KUID(atomic.LoadUint32(&a.uid)),
auth.KGID(atomic.LoadUint32(&a.gid)),
)
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 49d6efb0e..aee2a4392 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -172,14 +172,7 @@ func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.S
func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := i.Mode()
uid, gid := i.getOwner(mode)
- return vfs.GenericCheckPermissions(
- creds,
- ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
- uid,
- gid,
- )
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
}
func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 75d01b853..12cc64385 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -41,7 +41,7 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -125,7 +125,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -163,7 +163,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
if err != nil {
return err
}
- return d.inode.checkPermissions(creds, ats, d.inode.isDir())
+ return d.inode.checkPermissions(creds, ats)
}
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
@@ -178,7 +178,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -301,7 +301,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -316,7 +316,7 @@ afterTrailingSymlink:
child, err := stepLocked(rp, parent)
if err == syserror.ENOENT {
// Already checked for searchability above; now check for writability.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -347,7 +347,7 @@ afterTrailingSymlink:
func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
if !afterCreate {
- if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
}
@@ -428,7 +428,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer mnt.EndWrite()
oldParent := oldParentVD.Dentry().Impl().(*dentry)
- if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
// Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(),
@@ -445,7 +445,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
// Writability is needed to change renamed's "..".
- if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil {
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -455,7 +455,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
}
- if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
replacedVFSD := newParent.vfsd.Child(newName)
@@ -528,7 +528,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -621,7 +621,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 2d5070a46..2f9e6c876 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -245,8 +245,9 @@ func (i *inode) decRef() {
}
}
-func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
}
// Go won't inline this function, and returning linux.Statx (which is quite
@@ -299,7 +300,8 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&i.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
return err
}
i.mu.Lock()
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index d42eda37b..db6465663 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -394,6 +394,8 @@ TEXT ·Current(SB),NOSPLIT,$0-8
#define STACK_FRAME_SIZE 16
+// kernelExitToEl0 is the entrypoint for application in guest_el0.
+// Prepare the vcpu environment for container application.
TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
// Step1, save sentry context into memory.
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
@@ -464,7 +466,23 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
ERET()
+// kernelExitToEl1 is the entrypoint for sentry in guest_el1.
+// Prepare the vcpu environment for sentry.
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
+ MSR R1, ELR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
+ MOVD R1, RSP
+
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
+
ERET()
// Start is the CPU entrypoint.
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 925996517..a62e43589 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -83,7 +83,7 @@ func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds
if !rp.Done() {
return syserror.ENOTDIR
}
- return GenericCheckPermissions(creds, ats, false /* isDir */, anonFileMode, anonFileUID, anonFileGID)
+ return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID)
}
// GetDentryAt implements FilesystemImpl.GetDentryAt.
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 2c8f23f55..f9647f90e 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -29,9 +29,9 @@ type AccessTypes uint16
// Bits in AccessTypes.
const (
+ MayExec AccessTypes = 1
+ MayWrite AccessTypes = 2
MayRead AccessTypes = 4
- MayWrite = 2
- MayExec = 1
)
// OnlyRead returns true if access _only_ allows read.
@@ -56,16 +56,17 @@ func (a AccessTypes) MayExec() bool {
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
-// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
-func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir bool, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+// fs/namei.c:generic_permission().
+func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
// Check permission bits.
- perms := mode
+ perms := uint16(mode.Permissions())
if creds.EffectiveKUID == kuid {
perms >>= 6
} else if creds.InGroup(kgid) {
perms >>= 3
}
if uint16(ats)&perms == uint16(ats) {
+ // All permission bits match, access granted.
return nil
}
@@ -77,7 +78,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && !ats.MayWrite()) || ats.OnlyRead() {
+ if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -85,7 +86,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || !ats.MayExec() || (mode&0111 != 0) {
+ if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
@@ -151,7 +152,7 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
if stat.Mask&linux.STATX_SIZE != 0 {
limit, err := CheckLimit(ctx, 0, int64(stat.Size))
if err != nil {
@@ -190,11 +191,7 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
(stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
return syserror.EPERM
}
- // isDir is irrelevant in the following call to
- // GenericCheckPermissions since ats == MayWrite means that
- // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE
- // applies, regardless of isDir.
- if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
return err
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c55e3e8bc..9a33ed375 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,7 +35,7 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]*endpointsByNic
+ endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
@@ -46,11 +46,11 @@ type transportEndpoints struct {
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
+ epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
return
}
delete(eps.endpoints, id)
@@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
return es
}
-type endpointsByNic struct {
+// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+}
+
+// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
+// descending order of match quality.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
+ var matchedEPs []*endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given id.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
+ var matchedEP *endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
+}
+
+type endpointsByNIC struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
}
-func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
- epsByNic.mu.RLock()
- defer epsByNic.mu.RUnlock()
+func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
var eps []TransportEndpoint
- for _, ep := range epsByNic.endpoints {
+ for _, ep := range epsByNIC.endpoints {
eps = append(eps, ep.transportEndpoints()...)
}
return eps
@@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
- epsByNic.mu.RLock()
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+ epsByNIC.mu.RLock()
- mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
if !ok {
- if mpep, ok = epsByNic.endpoints[0]; !ok {
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
}
@@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
// endpoints bound to the right device.
if isMulticastOrBroadcast(id.LocalAddress) {
mpep.handlePacketAll(r, id, pkt)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(r, transEP, id, pkt)
- epsByNic.mu.RUnlock()
+ epsByNIC.mu.RUnlock()
return
}
transEP.HandlePacket(r, id, pkt)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
- epsByNic.mu.RLock()
- defer epsByNic.mu.RUnlock()
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
- mpep, ok := epsByNic.endpoints[n.ID()]
+ mpep, ok := epsByNIC.endpoints[n.ID()]
if !ok {
- mpep, ok = epsByNic.endpoints[0]
+ mpep, ok = epsByNIC.endpoints[0]
}
if !ok {
return
@@ -132,16 +199,16 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt)
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
- multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
multiPortEp = &multiPortEndpoint{
demux: d,
@@ -149,24 +216,24 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t
transProto: transProto,
reuse: reusePort,
}
- epsByNic.endpoints[bindToDevice] = multiPortEp
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
}
return multiPortEp.singleRegisterEndpoint(t, reusePort)
}
-// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
-func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
- multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
if multiPortEp.unregisterEndpoint(t) {
- delete(epsByNic.endpoints, bindToDevice)
+ delete(epsByNIC.endpoints, bindToDevice)
}
- return len(epsByNic.endpoints) == 0
+ return len(epsByNIC.endpoints) == 0
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
@@ -198,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for proto := range stack.transportProtocols {
protoIDs := protocolIDs{netProto, proto}
d.protocol[protoIDs] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]*endpointsByNic),
+ endpoints: make(map[TransportEndpointID]*endpointsByNIC),
}
qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
if isQueued {
@@ -378,16 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
+ epsByNIC, ok := eps.endpoints[id]
if !ok {
- epsByNic = &endpointsByNic{
+ epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
seed: rand.Uint32(),
}
- eps.endpoints[id] = epsByNic
+ eps.endpoints[id] = epsByNIC
}
- return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -413,7 +480,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// transport endpoints.
if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
eps.mu.RLock()
- destEPs := d.findAllEndpointsLocked(eps, id)
+ destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
@@ -439,7 +506,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
}
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, id)
+ ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
@@ -483,115 +550,47 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return false
}
- // Try to find the endpoint.
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, id)
+ ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
-
- // Fail if we didn't find one.
if ep == nil {
return false
}
- // Deliver the packet.
ep.handleControlPacket(n, id, typ, extra, pkt)
-
return true
}
-// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
-// descending order of match quality. If a call to yield returns false,
-// iterEndpointsLocked stops iteration and returns immediately.
-//
-// Preconditions: eps.mu must be locked.
-func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
- // Try to find a match with the id as provided.
- if ep, ok := eps.endpoints[id]; ok {
- if !yield(ep) {
- return
- }
- }
-
- // Try to find a match with the id minus the local address.
- nid := id
-
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- if !yield(ep) {
- return
- }
- }
-
- // Try to find a match with the id minus the remote part.
- nid.LocalAddress = id.LocalAddress
- nid.RemoteAddress = ""
- nid.RemotePort = 0
- if ep, ok := eps.endpoints[nid]; ok {
- if !yield(ep) {
- return
- }
- }
-
- // Try to find a match with only the local port.
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- if !yield(ep) {
- return
- }
- }
-}
-
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
- d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
- matchedEPs = append(matchedEPs, ep)
- return true
- })
- return matchedEPs
-}
-
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
}
- // Try to find the endpoint.
+
eps.mu.RLock()
- epsByNic := d.findEndpointLocked(eps, id)
- // Fail if we didn't find one.
- if epsByNic == nil {
+ epsByNIC := eps.findEndpointLocked(id)
+ if epsByNIC == nil {
eps.mu.RUnlock()
return nil
}
- epsByNic.mu.RLock()
+ epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
if !ok {
- if mpep, ok = epsByNic.endpoints[0]; !ok {
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return nil
}
}
- ep := selectEndpoint(id, mpep, epsByNic.seed)
- epsByNic.mu.RUnlock()
+ ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ epsByNIC.mu.RUnlock()
return ep
}
-// findEndpointLocked returns the endpoint that most closely matches the given
-// id.
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- var matchedEP *endpointsByNic
- d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
- matchedEP = ep
- return false
- })
- return matchedEP
-}
-
// registerRawEndpoint registers the given endpoint with the dispatcher such
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 84311bcc8..c65b0c632 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -31,84 +31,58 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testPort = 4096
+ testSrcAddrV4 = "\x0a\x00\x00\x01"
+ testDstAddrV4 = "\x0a\x00\x00\x02"
+
+ testDstPort = 1234
+ testSrcPort = 4096
)
type testContext struct {
- t *testing.T
linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createV6Endpoint(v6only bool) {
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
+ wq waiter.Queue
}
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
for _, linkEpID := range linkEpIDs {
channelEp := channel.New(256, mtu, "")
if err := s.CreateNIC(linkEpID, channelEp); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %s", err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
+ {Destination: header.IPv4EmptySubnet, NIC: 1},
+ {Destination: header.IPv6EmptySubnet, NIC: 1},
})
return &testContext{
- t: t,
s: s,
linkEps: linkEps,
}
}
type headers struct {
- srcPort uint16
- dstPort uint16
+ srcPort, dstPort uint16
}
func newPayload() []byte {
@@ -119,6 +93,47 @@ func newPayload() []byte {
return b
}
+func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: 0x80,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testSrcAddrV4,
+ DstAddr: testDstAddrV4,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
+}
+
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
@@ -130,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -143,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -151,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Inject packet.
c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
})
}
@@ -179,15 +196,15 @@ func TestTransportDemuxerRegister(t *testing.T) {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
- t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
}
}
-// TestReuseBindToDevice injects varied packets on input devices and checks that
+// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
-func TestDistribution(t *testing.T) {
+func TestBindToDeviceDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
bindToDevice tcpip.NICID
@@ -196,19 +213,19 @@ func TestDistribution(t *testing.T) {
name string
// endpoints will received the inject packets.
endpoints []endpointSockopts
- // wantedDistribution is the wanted ratio of packets received on each
+ // wantDistributions is the want ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[tcpip.NICID][]float64
+ wantDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
+ {reuse: 1, bindToDevice: 0},
+ {reuse: 1, bindToDevice: 0},
+ {reuse: 1, bindToDevice: 0},
+ {reuse: 1, bindToDevice: 0},
+ {reuse: 1, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
@@ -219,9 +236,9 @@ func TestDistribution(t *testing.T) {
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- {0, 1},
- {0, 2},
- {0, 3},
+ {reuse: 0, bindToDevice: 1},
+ {reuse: 0, bindToDevice: 2},
+ {reuse: 0, bindToDevice: 3},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
@@ -236,12 +253,12 @@ func TestDistribution(t *testing.T) {
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- {1, 1},
- {1, 1},
- {1, 2},
- {1, 2},
- {1, 2},
- {1, 0},
+ {reuse: 1, bindToDevice: 1},
+ {reuse: 1, bindToDevice: 1},
+ {reuse: 1, bindToDevice: 2},
+ {reuse: 1, bindToDevice: 2},
+ {reuse: 1, bindToDevice: 2},
+ {reuse: 1, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
@@ -255,17 +272,17 @@ func TestDistribution(t *testing.T) {
},
},
} {
- t.Run(test.name, func(t *testing.T) {
- for device, wantedDistribution := range test.wantedDistributions {
- t.Run(string(device), func(t *testing.T) {
+ for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ } {
+ for device, wantDistribution := range test.wantDistributions {
+ t.Run(test.name+protoName+string(device), func(t *testing.T) {
var devices []tcpip.NICID
- for d := range test.wantedDistributions {
+ for d := range test.wantDistributions {
devices = append(devices, d)
}
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
- defer c.cleanup()
-
- c.createV6Endpoint(false)
eps := make(map[tcpip.Endpoint]int)
@@ -279,9 +296,9 @@ func TestDistribution(t *testing.T) {
defer close(ch)
var err *tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps[ep] = i
@@ -294,20 +311,30 @@ func TestDistribution(t *testing.T) {
defer ep.Close()
reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
if err := ep.SetSockOpt(reusePortOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ }
+
+ var dstAddr tcpip.Address
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ dstAddr = testDstAddrV4
+ case ipv6.ProtocolNumber:
+ dstAddr = testDstAddrV6
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
- if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
npackets := 100000
nports := 10000
- if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
ports := make(map[uint16]tcpip.Endpoint)
@@ -316,17 +343,22 @@ func TestDistribution(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload,
- &headers{
- srcPort: testPort + port,
- dstPort: stackPort},
- device)
+ hdrs := &headers{
+ srcPort: testSrcPort + port,
+ dstPort: testDstPort,
+ }
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ c.sendV4Packet(payload, hdrs, device)
+ case ipv6.ProtocolNumber:
+ c.sendV6Packet(payload, hdrs, device)
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
- var addr tcpip.FullAddress
ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ if _, _, err := ep.Read(nil); err != nil {
+ t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
if i < nports {
@@ -342,17 +374,17 @@ func TestDistribution(t *testing.T) {
// Check that a packet distribution is as expected.
for ep, i := range eps {
- wantedRatio := wantedDistribution[i]
- wantedRecv := wantedRatio * float64(npackets)
+ wantRatio := wantDistribution[i]
+ wantRecv := wantRatio * float64(npackets)
actualRecv := stats[ep]
actualRatio := float64(stats[ep]) / float64(npackets)
// The deviation is less than 10%.
- if math.Abs(actualRatio-wantedRatio) > 0.05 {
- t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ if math.Abs(actualRatio-wantRatio) > 0.05 {
+ t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
}
}
})
}
- })
+ }
}
}