diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/abi/linux/file.go | 5 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/ext/inode.go | 2 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/gofer/filesystem.go | 22 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/gofer/gofer.go | 7 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/host/host.go | 6 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 3 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 6 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/proc/task.go | 9 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/tmpfs/filesystem.go | 24 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/tmpfs/tmpfs.go | 8 | ||||
-rw-r--r-- | pkg/sentry/platform/ring0/entry_arm64.s | 18 | ||||
-rw-r--r-- | pkg/sentry/vfs/anonfs.go | 2 | ||||
-rw-r--r-- | pkg/sentry/vfs/permissions.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 239 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 222 |
15 files changed, 320 insertions, 276 deletions
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index dbe58acbe..055ac1d7c 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -287,6 +287,11 @@ func (m FileMode) ExtraBits() FileMode { return m &^ (PermissionsMask | FileTypeMask) } +// IsDir returns true if file type represents a directory. +func (m FileMode) IsDir() bool { + return m.FileType() == S_IFDIR +} + // String returns a string representation of m. func (m FileMode) String() string { var s []string diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 6962083f5..a39a37318 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -186,7 +186,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt } func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { - return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID()) + return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID()) } // statTo writes the statx fields to the output parameter. diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 26b492185..1e43df9ec 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -119,7 +119,7 @@ func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d * if !d.isDir() { return nil, syserror.ENOTDIR } - if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } afterSymlink: @@ -314,7 +314,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } if parent.isDeleted() { @@ -378,7 +378,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -512,7 +512,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds if err != nil { return err } - return d.checkPermissions(creds, ats, d.isDir()) + return d.checkPermissions(creds, ats) } // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. @@ -528,7 +528,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op if !d.isDir() { return nil, syserror.ENOTDIR } - if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } } @@ -624,7 +624,7 @@ afterTrailingSymlink: return nil, err } // Check for search permission in the parent directory. - if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } // Determine whether or not we need to create a file. @@ -661,7 +661,7 @@ afterTrailingSymlink: // Preconditions: fs.renameMu must be locked. func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) - if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil { + if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } mnt := rp.Mount() @@ -722,7 +722,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { - if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { + if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } if d.isDeleted() { @@ -884,7 +884,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return err } } - if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { + if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } vfsObj := rp.VirtualFilesystem() @@ -904,7 +904,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return syserror.EINVAL } if oldParent != newParent { - if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { + if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return err } } @@ -915,7 +915,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if oldParent != newParent { - if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil { + if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } newParent.dirMu.Lock() diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 13928ce36..cf276a417 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -721,7 +721,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { @@ -843,8 +844,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return nil } -func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { - return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) +func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } // IncRef implements vfs.DentryImpl.IncRef. diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 1f735628f..a54985ef5 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -167,8 +167,8 @@ func fileFlagsFromHostFD(fd int) (int, error) { } // CheckPermissions implements kernfs.Inode. -func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error { - return vfs.GenericCheckPermissions(creds, atx, false /* isDir */, uint16(i.mode), i.uid, i.gid) +func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, ats, i.mode, i.uid, i.gid) } // Mode implements kernfs.Inode. @@ -306,7 +306,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &s, uint16(i.Mode().Permissions()), i.uid, i.gid); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &s, i.Mode(), i.uid, i.gid); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 75c4bab1a..bfa786c88 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -206,8 +206,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - fs := fd.filesystem() creds := auth.CredentialsFromContext(ctx) inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fs, creds, opts) + return inode.SetStat(ctx, fd.filesystem(), creds, opts) } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index c612dcf07..5c84b10c9 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -241,7 +241,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, uint16(a.Mode().Permissions()), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -273,12 +273,10 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut // CheckPermissions implements Inode.CheckPermissions. func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { - mode := a.Mode() return vfs.GenericCheckPermissions( creds, ats, - mode.FileType() == linux.ModeDirectory, - uint16(mode), + a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid)), ) diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 49d6efb0e..aee2a4392 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -172,14 +172,7 @@ func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.S func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { mode := i.Mode() uid, gid := i.getOwner(mode) - return vfs.GenericCheckPermissions( - creds, - ats, - mode.FileType() == linux.ModeDirectory, - uint16(mode), - uid, - gid, - ) + return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid) } func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 75d01b853..12cc64385 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -41,7 +41,7 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { if !d.inode.isDir() { return nil, syserror.ENOTDIR } - if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } afterSymlink: @@ -125,7 +125,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if err != nil { return err } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } name := rp.Component() @@ -163,7 +163,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds if err != nil { return err } - return d.inode.checkPermissions(creds, ats, d.inode.isDir()) + return d.inode.checkPermissions(creds, ats) } // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. @@ -178,7 +178,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op if !d.inode.isDir() { return nil, syserror.ENOTDIR } - if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil { + if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } } @@ -301,7 +301,7 @@ afterTrailingSymlink: return nil, err } // Check for search permission in the parent directory. - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil { + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } // Reject attempts to open directories with O_CREAT. @@ -316,7 +316,7 @@ afterTrailingSymlink: child, err := stepLocked(rp, parent) if err == syserror.ENOENT { // Already checked for searchability above; now check for writability. - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil { + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -347,7 +347,7 @@ afterTrailingSymlink: func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if !afterCreate { - if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil { + if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } } @@ -428,7 +428,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa defer mnt.EndWrite() oldParent := oldParentVD.Dentry().Impl().(*dentry) - if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } // Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(), @@ -445,7 +445,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if oldParent != newParent { // Writability is needed to change renamed's "..". - if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil { + if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return err } } @@ -455,7 +455,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } } - if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } replacedVFSD := newParent.vfsd.Child(newName) @@ -528,7 +528,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if err != nil { return err } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } name := rp.Component() @@ -621,7 +621,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err != nil { return err } - if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil { + if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return err } name := rp.Component() diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 2d5070a46..2f9e6c876 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -245,8 +245,9 @@ func (i *inode) decRef() { } } -func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error { - return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))) +func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))) } // Go won't inline this function, and returning linux.Statx (which is quite @@ -299,7 +300,8 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&i.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { return err } i.mu.Lock() diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index d42eda37b..db6465663 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -394,6 +394,8 @@ TEXT ·Current(SB),NOSPLIT,$0-8 #define STACK_FRAME_SIZE 16 +// kernelExitToEl0 is the entrypoint for application in guest_el0. +// Prepare the vcpu environment for container application. TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 // Step1, save sentry context into memory. REGISTERS_SAVE(RSV_REG, CPU_REGISTERS) @@ -464,7 +466,23 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 ERET() +// kernelExitToEl1 is the entrypoint for sentry in guest_el1. +// Prepare the vcpu environment for sentry. TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + + MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1 + WORD $0xd5184001 //MSR R1, SPSR_EL1 + + MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1 + MSR R1, ELR_EL1 + + MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 + MOVD R1, RSP + + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) + MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP + ERET() // Start is the CPU entrypoint. diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 925996517..a62e43589 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -83,7 +83,7 @@ func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds if !rp.Done() { return syserror.ENOTDIR } - return GenericCheckPermissions(creds, ats, false /* isDir */, anonFileMode, anonFileUID, anonFileGID) + return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID) } // GetDentryAt implements FilesystemImpl.GetDentryAt. diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 2c8f23f55..f9647f90e 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -29,9 +29,9 @@ type AccessTypes uint16 // Bits in AccessTypes. const ( + MayExec AccessTypes = 1 + MayWrite AccessTypes = 2 MayRead AccessTypes = 4 - MayWrite = 2 - MayExec = 1 ) // OnlyRead returns true if access _only_ allows read. @@ -56,16 +56,17 @@ func (a AccessTypes) MayExec() bool { // GenericCheckPermissions checks that creds has the given access rights on a // file with the given permissions, UID, and GID, subject to the rules of -// fs/namei.c:generic_permission(). isDir is true if the file is a directory. -func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir bool, mode uint16, kuid auth.KUID, kgid auth.KGID) error { +// fs/namei.c:generic_permission(). +func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { // Check permission bits. - perms := mode + perms := uint16(mode.Permissions()) if creds.EffectiveKUID == kuid { perms >>= 6 } else if creds.InGroup(kgid) { perms >>= 3 } if uint16(ats)&perms == uint16(ats) { + // All permission bits match, access granted. return nil } @@ -77,7 +78,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo } // CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary // directories, and read arbitrary non-directory files. - if (isDir && !ats.MayWrite()) || ats.OnlyRead() { + if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() { if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) { return nil } @@ -85,7 +86,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo // CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write // access to non-directory files, and execute access to non-directory files // for which at least one execute bit is set. - if isDir || !ats.MayExec() || (mode&0111 != 0) { + if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) { if creds.HasCapability(linux.CAP_DAC_OVERRIDE) { return nil } @@ -151,7 +152,7 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { if stat.Mask&linux.STATX_SIZE != 0 { limit, err := CheckLimit(ctx, 0, int64(stat.Size)) if err != nil { @@ -190,11 +191,7 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) { return syserror.EPERM } - // isDir is irrelevant in the following call to - // GenericCheckPermissions since ats == MayWrite means that - // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE - // applies, regardless of isDir. - if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil { + if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil { return err } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c55e3e8bc..9a33ed375 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]*endpointsByNic + endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -46,11 +46,11 @@ type transportEndpoints struct { func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { return } delete(eps.endpoints, id) @@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { return es } -type endpointsByNic struct { +// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { + // Try to find a match with the id as provided. + if ep, ok := eps.endpoints[id]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } +} + +// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in +// descending order of match quality. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { + var matchedEPs []*endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given id. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { + var matchedEP *endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEP = ep + return false + }) + return matchedEP +} + +type endpointsByNIC struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 } -func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() var eps []TransportEndpoint - for _, ep := range epsByNic.endpoints { + for _, ep := range epsByNIC.endpoints { eps = append(eps, ep.transportEndpoints()...) } return eps @@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { - epsByNic.mu.RLock() +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { + epsByNIC.mu.RLock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } } @@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNic.seed) + transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(r, transEP, id, pkt) - epsByNic.mu.RUnlock() + epsByNIC.mu.RUnlock() return } transEP.HandlePacket(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() - mpep, ok := epsByNic.endpoints[n.ID()] + mpep, ok := epsByNIC.endpoints[n.ID()] if !ok { - mpep, ok = epsByNic.endpoints[0] + mpep, ok = epsByNIC.endpoints[0] } if !ok { return @@ -132,16 +199,16 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { multiPortEp = &multiPortEndpoint{ demux: d, @@ -149,24 +216,24 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t transProto: transProto, reuse: reusePort, } - epsByNic.endpoints[bindToDevice] = multiPortEp + epsByNIC.endpoints[bindToDevice] = multiPortEp } return multiPortEp.singleRegisterEndpoint(t, reusePort) } -// unregisterEndpoint returns true if endpointsByNic has to be unregistered. -func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] +// unregisterEndpoint returns true if endpointsByNIC has to be unregistered. +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } if multiPortEp.unregisterEndpoint(t) { - delete(epsByNic.endpoints, bindToDevice) + delete(epsByNIC.endpoints, bindToDevice) } - return len(epsByNic.endpoints) == 0 + return len(epsByNIC.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -198,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for proto := range stack.transportProtocols { protoIDs := protocolIDs{netProto, proto} d.protocol[protoIDs] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]*endpointsByNic), + endpoints: make(map[TransportEndpointID]*endpointsByNIC), } qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) if isQueued { @@ -378,16 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { - epsByNic = &endpointsByNic{ + epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), seed: rand.Uint32(), } - eps.endpoints[id] = epsByNic + eps.endpoints[id] = epsByNIC } - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -413,7 +480,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // transport endpoints. if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { eps.mu.RLock() - destEPs := d.findAllEndpointsLocked(eps, id) + destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { @@ -439,7 +506,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { @@ -483,115 +550,47 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find one. if ep == nil { return false } - // Deliver the packet. ep.handleControlPacket(n, id, typ, extra, pkt) - return true } -// iterEndpointsLocked yields all endpointsByNic in eps that match id, in -// descending order of match quality. If a call to yield returns false, -// iterEndpointsLocked stops iteration and returns immediately. -// -// Preconditions: eps.mu must be locked. -func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { - // Try to find a match with the id as provided. - if ep, ok := eps.endpoints[id]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the local address. - nid := id - - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the remote part. - nid.LocalAddress = id.LocalAddress - nid.RemoteAddress = "" - nid.RemotePort = 0 - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with only the local port. - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } -} - -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEPs = append(matchedEPs, ep) - return true - }) - return matchedEPs -} - // findTransportEndpoint find a single endpoint that most closely matches the provided id. func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } - // Try to find the endpoint. + eps.mu.RLock() - epsByNic := d.findEndpointLocked(eps, id) - // Fail if we didn't find one. - if epsByNic == nil { + epsByNIC := eps.findEndpointLocked(id) + if epsByNIC == nil { eps.mu.RUnlock() return nil } - epsByNic.mu.RLock() + epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return nil } } - ep := selectEndpoint(id, mpep, epsByNic.seed) - epsByNic.mu.RUnlock() + ep := selectEndpoint(id, mpep, epsByNIC.seed) + epsByNIC.mu.RUnlock() return ep } -// findEndpointLocked returns the endpoint that most closely matches the given -// id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - var matchedEP *endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEP = ep - return false - }) - return matchedEP -} - // registerRawEndpoint registers the given endpoint with the dispatcher such // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 84311bcc8..c65b0c632 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -31,84 +31,58 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 + testSrcAddrV4 = "\x0a\x00\x00\x01" + testDstAddrV4 = "\x0a\x00\x00\x02" + + testDstPort = 1234 + testSrcPort = 4096 ) type testContext struct { - t *testing.T linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } + wq waiter.Queue } // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { channelEp := channel.New(256, mtu, "") if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { + t.Fatalf("AddAddress IPv4 failed: %s", err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { + t.Fatalf("AddAddress IPv6 failed: %s", err) } } s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, + {Destination: header.IPv4EmptySubnet, NIC: 1}, + {Destination: header.IPv6EmptySubnet, NIC: 1}, }) return &testContext{ - t: t, s: s, linkEps: linkEps, } } type headers struct { - srcPort uint16 - dstPort uint16 + srcPort, dstPort uint16 } func newPayload() []byte { @@ -119,6 +93,47 @@ func newPayload() []byte { return b } +func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: 0x80, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testSrcAddrV4, + DstAddr: testDstAddrV4, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) +} + func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) @@ -130,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -143,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -151,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Inject packet. c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), }) } @@ -179,15 +196,15 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) } } -// TestReuseBindToDevice injects varied packets on input devices and checks that +// TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { +func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { reuse int bindToDevice tcpip.NICID @@ -196,19 +213,19 @@ func TestDistribution(t *testing.T) { name string // endpoints will received the inject packets. endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each + // wantDistributions is the want ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[tcpip.NICID][]float64 + wantDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -219,9 +236,9 @@ func TestDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, 1}, - {0, 2}, - {0, 3}, + {reuse: 0, bindToDevice: 1}, + {reuse: 0, bindToDevice: 2}, + {reuse: 0, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -236,12 +253,12 @@ func TestDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, 1}, - {1, 1}, - {1, 2}, - {1, 2}, - {1, 2}, - {1, 0}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -255,17 +272,17 @@ func TestDistribution(t *testing.T) { }, }, } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(string(device), func(t *testing.T) { + for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } { + for device, wantDistribution := range test.wantDistributions { + t.Run(test.name+protoName+string(device), func(t *testing.T) { var devices []tcpip.NICID - for d := range test.wantedDistributions { + for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) eps := make(map[tcpip.Endpoint]int) @@ -279,9 +296,9 @@ func TestDistribution(t *testing.T) { defer close(ch) var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps[ep] = i @@ -294,20 +311,30 @@ func TestDistribution(t *testing.T) { defer ep.Close() reusePortOption := tcpip.ReusePortOption(endpoint.reuse) if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + } + + var dstAddr tcpip.Address + switch netProtoNum { + case ipv4.ProtocolNumber: + dstAddr = testDstAddrV4 + case ipv6.ProtocolNumber: + dstAddr = testDstAddrV6 + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } npackets := 100000 nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { + if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } ports := make(map[uint16]tcpip.Endpoint) @@ -316,17 +343,22 @@ func TestDistribution(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) + hdrs := &headers{ + srcPort: testSrcPort + port, + dstPort: testDstPort, + } + switch netProtoNum { + case ipv4.ProtocolNumber: + c.sendV4Packet(payload, hdrs, device) + case ipv6.ProtocolNumber: + c.sendV6Packet(payload, hdrs, device) + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } - var addr tcpip.FullAddress ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + if _, _, err := ep.Read(nil); err != nil { + t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ if i < nports { @@ -342,17 +374,17 @@ func TestDistribution(t *testing.T) { // Check that a packet distribution is as expected. for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) + wantRatio := wantDistribution[i] + wantRecv := wantRatio * float64(npackets) actualRecv := stats[ep] actualRatio := float64(stats[ep]) / float64(npackets) // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + if math.Abs(actualRatio-wantRatio) > 0.05 { + t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) } } }) } - }) + } } } |