summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE18
-rw-r--r--kokoro/build.cfg1
-rw-r--r--kokoro/build_tests.cfg1
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go22
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go4
-rw-r--r--pkg/sentry/fs/gofer/inode.go15
-rw-r--r--pkg/sentry/fs/host/inode.go4
-rw-r--r--pkg/sentry/fs/splice.go4
-rw-r--r--pkg/sentry/loader/elf.go18
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go2
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go20
-rw-r--r--pkg/sentry/socket/epsocket/provider.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go39
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go6
-rw-r--r--pkg/sentry/usermem/usermem.go8
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go14
-rw-r--r--pkg/tcpip/ports/ports.go161
-rw-r--r--pkg/tcpip/ports/ports_test.go169
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/nic.go15
-rw-r--r--pkg/tcpip/stack/stack.go83
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go227
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go352
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/tcpip.go4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go73
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go116
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go26
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go42
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go120
-rw-r--r--runsc/BUILD2
-rw-r--r--runsc/cmd/exec.go52
-rw-r--r--runsc/cmd/exec_test.go4
-rw-r--r--runsc/container/BUILD1
-rw-r--r--runsc/container/container_test.go25
-rw-r--r--runsc/container/test_app/test_app.go65
-rw-r--r--runsc/dockerutil/dockerutil.go9
-rw-r--r--runsc/fsgofer/fsgofer.go19
-rw-r--r--runsc/fsgofer/fsgofer_test.go2
-rw-r--r--runsc/specutils/BUILD1
-rw-r--r--runsc/specutils/specutils.go10
-rwxr-xr-xscripts/build.sh2
-rw-r--r--test/e2e/BUILD2
-rw-r--r--test/e2e/exec_test.go65
-rw-r--r--test/syscalls/linux/BUILD77
-rw-r--r--test/syscalls/linux/clock_nanosleep.cc86
-rw-r--r--test/syscalls/linux/exec_binary.cc19
-rw-r--r--test/syscalls/linux/packet_socket.cc29
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc21
-rw-r--r--test/syscalls/linux/pty.cc25
-rw-r--r--test/syscalls/linux/raw_socket_hdrincl.cc43
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc13
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc13
-rw-r--r--test/syscalls/linux/socket_bind_to_device.cc314
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc381
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc316
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.cc75
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.h67
-rw-r--r--test/syscalls/linux/uidgid.cc25
-rw-r--r--test/syscalls/linux/uname.cc14
-rw-r--r--test/util/BUILD11
-rw-r--r--test/util/proc_util.cc2
-rw-r--r--test/util/uid_util.cc44
-rw-r--r--test/util/uid_util.h29
72 files changed, 2913 insertions, 552 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 77750f9e6..082e26ee9 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,10 +3,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "ae8c36ff6e565f674c7a3692d6a9ea1096e4c1ade497272c2108a810fb39acd2",
+ sha256 = "513c12397db1bc9aa46dd62f02dd94b49a9b5d17444d49b5a04c5a89f3053c1c",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13",
+ go_version = "1.13.1",
nogo = "@//:nogo",
)
@@ -75,6 +75,16 @@ load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
rbe_autoconfig(name = "rbe_default")
+http_archive(
+ name = "rules_pkg",
+ sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c",
+ url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz",
+)
+
+load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
+rules_pkg_dependencies()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
index 084347dde..6c1d262d4 100644
--- a/kokoro/build.cfg
+++ b/kokoro/build.cfg
@@ -16,6 +16,7 @@ env_vars {
action {
define_artifacts {
+ regex: "**/runsc"
regex: "**/runsc.*"
regex: "**/dists/**"
}
diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg
new file mode 100644
index 000000000..c64b7e679
--- /dev/null
+++ b/kokoro/build_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/build.sh"
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index d2495cb83..693625ddc 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
- if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index d404a79d4..dd80757dc 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -140,12 +140,16 @@ type CachedFileObject interface {
// WriteFromBlocksAt may return a partial write without an error.
WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
- // SetMaskedAttributes sets the attributes in attr that are true in mask
- // on the backing file.
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only ATime or MTime
+ // and the CachedFileObject has an FD to the file, then this operation
+ // is a noop unless forceSetTimestamps is true. This avoids an extra
+ // RPC to the gofer in the open-read/write-close case, when the
+ // timestamps on the file will be updated by the host kernel for us.
//
// SetMaskedAttributes may be called at any point, regardless of whether
// the file was opened.
- SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
// Allocate allows the caller to reserve disk space for the inode.
// It's equivalent to fallocate(2) with 'mode=0'.
@@ -224,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Perms: true}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
return false
}
c.attr.Perms = perms
@@ -246,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode,
UID: owner.UID.Ok(),
GID: owner.GID.Ok(),
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
return err
}
if owner.UID.Ok() {
@@ -282,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In
AccessTime: !ts.ATimeOmit,
ModificationTime: !ts.MTimeOmit,
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
return err
}
if !ts.ATimeOmit {
@@ -305,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
}
@@ -394,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
c.dirtyAttr.Size = false
// Write out cached attributes.
- if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
c.attrMu.Unlock()
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index eb5730c35..129f314c8 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block
return srcs.NumBytes(), nil
}
-func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B
return w.WriteFromBlocks(srcs)
}
-func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 95b064aea..d918d6620 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
- if i.skipSetAttr(mask) {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
return nil
}
as, ans := attr.AccessTime.Unix()
@@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa
// when:
// - Mask is empty
// - Mask contains only attributes that cannot be set in the gofer
-// - Mask contains only atime and/or mtime, and host FD exists
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
//
// Updates to atime and mtime can be skipped because cached value will be
// "close enough" to host value, given that operation went directly to host FD.
// Skipping atime updates is particularly important to reduce the number of
// operations sent to the Gofer for readonly files.
-func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
// First remove attributes that cannot be updated.
cpy := mask
cpy.Type = false
@@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
return false
}
+ // If forceSetTimestamps was passed, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
i.handlesMu.RLock()
defer i.handlesMu.RUnlock()
return (i.readHandles != nil && i.readHandles.Host != nil) ||
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 894ab01f0..a6e4a09e3 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
if mask.Empty() {
return nil
}
@@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
return unstableAttr(i.mops, &s), nil
}
-// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+// Allocate implements fsutil.CachedFileObject.Allocate.
func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
return syscall.Fallocate(i.FD(), 0, offset, length)
}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index b03b7f836..311798811 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -139,7 +139,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// Attempt to do a WriteTo; this is likely the most efficient.
n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
- if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup {
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
// Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
// more efficient than a copy if buffers are cached or readily
// available. (It's unlikely that they can actually be donated).
@@ -151,7 +151,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// if we block at some point, we could lose data. If the source is
// not a pipe then reading is not destructive; if the destination
// is a regular file, then it is guaranteed not to block writing.
- if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) {
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
// Fallback to an in-kernel copy.
n, err = io.Copy(w, &io.LimitedReader{
R: r,
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index ba9c9ce12..2d9251e92 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
return syserror.ENOEXEC
}
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit, always mapping at least RW. ignoring These
+ // pages are not included in the final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
if _, err := m.MMap(ctx, memmap.MMapOpts{
Length: uint64(anonSize),
Addr: anonAddr,
// Fixed without Unmap will fail the mmap if something is
// already at addr.
- Fixed: true,
- Private: true,
- // N.B. Linux uses vm_brk to map these pages, ignoring
- // the segment protections, instead always mapping RW.
- // These pages are not included in the final brk
- // region.
- Perms: usermem.ReadWrite,
+ Fixed: true,
+ Private: true,
+ Perms: prot,
MaxPerms: usermem.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index eace3766d..c303435d5 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
// pkg/sentry/arch.
type sigaction struct {
handler uintptr
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 3e66f9cbb..5812085fa 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -942,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1305,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 421f93dc4..0a9dfa6c3 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index f0a292f2f..9f705ebca 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -245,12 +245,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
- default:
- return 0, nil, syserror.EINVAL
- }
- // We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 4b3f043a2..b887fa9d7 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use
timer.Destroy()
- var remaining time.Duration
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- remaining = 0
- } else {
- remaining = dur - after.Sub(start)
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
if remaining < 0 {
remaining = time.Duration(0)
}
- }
- // Copy out remaining time.
- if err != nil && rem != usermem.Addr(0) {
- timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
- if err := copyTimespecOut(t, rem, &timeleft); err != nil {
- return err
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
}
- }
-
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- return nil
- }
- // If interrupted, arrange for a restart with the remaining duration.
- if err == syserror.ErrInterrupted {
+ // Arrange for a restart with the remaining duration.
t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
c: c,
duration: remaining,
rem: rem,
})
return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
-
- return err
}
// Nanosleep implements linux syscall Nanosleep(2).
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 6eced660a..7b1f312b1 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -16,6 +16,7 @@
package usermem
import (
+ "bytes"
"errors"
"io"
"strconv"
@@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
- for i, c := range buf[done : done+n] {
- if c == 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
- }
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
}
+
done += n
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 584db710e..7636418b1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -82,7 +82,19 @@ const (
PacketMMap
)
-// An endpoint implements the link-layer using a message-oriented file descriptor.
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ }
+}
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..30cea8996 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -19,6 +19,7 @@ import (
"math"
"math/rand"
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,6 +28,10 @@ const (
// FirstEphemeral is the first ephemeral port.
FirstEphemeral = 16000
+ // numEphemeralPorts it the mnumber of available ephemeral ports to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
anyIPAddress tcpip.Address = ""
)
@@ -40,6 +45,13 @@ type portDescriptor struct {
type PortManager struct {
mu sync.RWMutex
allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
}
type portNode struct {
@@ -47,43 +59,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all portNodes. If binding to a specific device, check
+// against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to binding all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
+ // Check that this is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
return true
}
@@ -97,11 +142,40 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - FirstEphemeral + 1)
- offset := uint16(rand.Int31n(int32(count)))
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
- for i := uint16(0); i < count; i++ {
- port = FirstEphemeral + (offset+i)%count
+// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// iterates over all ephemeral ports, allowing the caller to decide whether a
+// given port is suitable for its needs and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+ if err == nil {
+ s.incPortHint()
+ }
+ return p, err
+
+}
+
+// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
+// and iterates over the number of ports specified by count and allows the
+// caller to decide whether a given port is suitable for its needs, and stopping
+// when a port is found or an error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..19f4833fc 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,6 +35,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
@@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}
+
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28c49e8ff..baf88bfab 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -36,6 +36,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -54,6 +55,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -64,6 +66,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0e8a23f00..f6106f762 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,8 +34,6 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
mu sync.RWMutex
spoofing bool
promiscuous bool
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -707,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -723,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -767,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 18d1704a5..90c2cf1be 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,10 +20,12 @@
package stack
import (
+ "encoding/binary"
"sync"
"time"
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -388,6 +390,12 @@ type Stack struct {
// icmpRateLimiter is a global rate limiter for all ICMP messages generated
// by the stack.
icmpRateLimiter *ICMPRateLimiter
+
+ // portSeed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections
+ //
+ // TODO(gvisor.dev/issues/940): S/R this field.
+ portSeed uint32
}
// Options contains optional Stack configuration.
@@ -440,6 +448,7 @@ func New(opts Options) *Stack {
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
}
// Add specified network protocols.
@@ -1033,73 +1042,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1243,3 +1206,19 @@ func (s *Stack) SetICMPBurst(burst int) {
func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
+
+// PortSeed returns a 32 bit value that can be used as a seed value for port
+// picking.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) PortSeed() uint32 {
+ return s.portSeed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..8c768c299 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint returns true if it succeeds. It fails and returns
+// false if ep already has an element with the same key.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list might be empty already.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
+ destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
@@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
return ep
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and also linkEpNames
+// named NICs.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestReuseBindToDevice injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will received the inject packets.
+ endpoints []endpointSockopts
+ // wantedDistribution is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The deviation is less than 10%.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 56e8a5d9b..842a16277 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -127,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -168,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c021c67ac..faaa4a4e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -495,6 +495,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index a111fdb2a..a3a910d41 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -104,7 +104,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -543,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 39a839ab7..a42e1f4a2 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 0802e984e..3ae4a5426 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 35b489c68..f9d5e0085 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -280,6 +282,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -564,11 +569,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +630,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1060,6 +1065,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.QuickAckOption:
if v == 0 {
atomic.StoreUint32(&e.slowAck, 1)
@@ -1260,6 +1280,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1466,7 +1496,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1476,17 +1506,32 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+
+ // Calculate a port offset based on the destination IP/port and
+ // src IP to ensure that for a given tuple (srcIP, destIP,
+ // destPort) the offset used as a starting point is the same to
+ // ensure that we can cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.id.LocalAddress))
+ h.Write([]byte(e.id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
e.id = id
return true, nil
@@ -1504,7 +1549,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1648,7 +1693,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1729,7 +1774,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1739,16 +1784,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.id.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.id.LocalPort = 0
e.id.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 2be094876..089826a88 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -465,6 +465,66 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -2970,6 +3030,62 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index d3f1d2cdf..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ if testing.Verbose() {
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ }
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -588,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -606,6 +609,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c1ca22b35..7a635ab8d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -52,6 +52,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 0bec7e62d..52f5af777 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -88,6 +88,7 @@ type endpoint struct {
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
// shutdownFlags represent the current shutdown state of the endpoint.
@@ -144,8 +145,8 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
@@ -551,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -646,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -753,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.id.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.id = id
e.route.Release()
e.route = stack.Route{}
@@ -835,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
}
e.id = id
@@ -898,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a9edc2c8d..2d0bc5221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 2ec27be4d..5059ca22d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -476,87 +476,59 @@ func newMinPayload(minSize int) []byte {
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
diff --git a/runsc/BUILD b/runsc/BUILD
index 5e7dacb87..a3a0d6730 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -1,7 +1,7 @@
package(licenses = ["notice"]) # Apache 2.0
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar")
+load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar")
go_binary(
name = "runsc",
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index bf1225e1c..d1e99243b 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -105,11 +105,11 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute. It starts a process in an
// already created container.
func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- e, id, err := ex.parseArgs(f)
+ conf := args[0].(*boot.Config)
+ e, id, err := ex.parseArgs(f, conf.EnableRaw)
if err != nil {
Fatalf("parsing process spec: %v", err)
}
- conf := args[0].(*boot.Config)
waitStatus := args[1].(*syscall.WaitStatus)
c, err := container.Load(conf.RootDir, id)
@@ -117,6 +117,9 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("loading sandbox: %v", err)
}
+ log.Debugf("Exec arguments: %+v", e)
+ log.Debugf("Exec capablities: %+v", e.Capabilities)
+
// Replace empty settings with defaults from container.
if e.WorkingDirectory == "" {
e.WorkingDirectory = c.Spec.Process.Cwd
@@ -129,14 +132,11 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
if e.Capabilities == nil {
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- e.Capabilities, err = specutils.Capabilities(true /* enableRaw */, c.Spec.Process.Capabilities)
+ e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities)
if err != nil {
Fatalf("creating capabilities: %v", err)
}
+ log.Infof("Using exec capabilities from container: %+v", e.Capabilities)
}
// containerd expects an actual process to represent the container being
@@ -283,14 +283,14 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
// parseArgs parses exec information from the command line or a JSON file
// depending on whether the --process flag was used. Returns an ExecArgs and
// the ID of the container to be used.
-func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
+func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) {
if ex.processPath == "" {
// Requires at least a container ID and command.
if f.NArg() < 2 {
f.Usage()
return nil, "", fmt.Errorf("both a container-id and command are required")
}
- e, err := ex.argsFromCLI(f.Args()[1:])
+ e, err := ex.argsFromCLI(f.Args()[1:], enableRaw)
return e, f.Arg(0), err
}
// Requires only the container ID.
@@ -298,11 +298,11 @@ func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
f.Usage()
return nil, "", fmt.Errorf("a container-id is required")
}
- e, err := ex.argsFromProcessFile()
+ e, err := ex.argsFromProcessFile(enableRaw)
return e, f.Arg(0), err
}
-func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
+func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) {
extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs))
for _, s := range ex.extraKGIDs {
kgid, err := strconv.Atoi(s)
@@ -315,7 +315,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
var caps *auth.TaskCapabilities
if len(ex.caps) > 0 {
var err error
- caps, err = capabilities(ex.caps)
+ caps, err = capabilities(ex.caps, enableRaw)
if err != nil {
return nil, fmt.Errorf("capabilities error: %v", err)
}
@@ -333,7 +333,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
}, nil
}
-func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
+func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) {
f, err := os.Open(ex.processPath)
if err != nil {
return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err)
@@ -343,21 +343,21 @@ func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err)
}
- return argsFromProcess(&p)
+ return argsFromProcess(&p, enableRaw)
}
// argsFromProcess performs all the non-IO conversion from the Process struct
// to ExecArgs.
-func argsFromProcess(p *specs.Process) (*control.ExecArgs, error) {
+func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) {
// Create capabilities.
var caps *auth.TaskCapabilities
if p.Capabilities != nil {
var err error
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- caps, err = specutils.Capabilities(true /* enableRaw */, p.Capabilities)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ caps, err = specutils.Capabilities(enableRaw, p.Capabilities)
if err != nil {
return nil, fmt.Errorf("error creating capabilities: %v", err)
}
@@ -410,7 +410,7 @@ func resolveEnvs(envs ...[]string) ([]string, error) {
// capabilities takes a list of capabilities as strings and returns an
// auth.TaskCapabilities struct with those capabilities in every capability set.
// This mimics runc's behavior.
-func capabilities(cs []string) (*auth.TaskCapabilities, error) {
+func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) {
var specCaps specs.LinuxCapabilities
for _, cap := range cs {
specCaps.Ambient = append(specCaps.Ambient, cap)
@@ -419,11 +419,11 @@ func capabilities(cs []string) (*auth.TaskCapabilities, error) {
specCaps.Inheritable = append(specCaps.Inheritable, cap)
specCaps.Permitted = append(specCaps.Permitted, cap)
}
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec requires
- // the capability to be set explicitly, while 'docker run' sets it by
- // default.
- return specutils.Capabilities(true /* enableRaw */, &specCaps)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ return specutils.Capabilities(enableRaw, &specCaps)
}
// stringSlice allows a flag to be used multiple times, where each occurrence
diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go
index eb38a431f..a1e980d08 100644
--- a/runsc/cmd/exec_test.go
+++ b/runsc/cmd/exec_test.go
@@ -91,7 +91,7 @@ func TestCLIArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := tc.ex.argsFromCLI(tc.argv)
+ e, err := tc.ex.argsFromCLI(tc.argv, true)
if err != nil {
t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
@@ -144,7 +144,7 @@ func TestJSONArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := argsFromProcess(&tc.p)
+ e, err := argsFromProcess(&tc.p, true)
if err != nil {
t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index bc1fa25e3..26d1cd5ab 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -47,6 +47,7 @@ go_test(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel",
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 2ac12e5b6..519f5ed9b 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -34,6 +34,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -2049,6 +2050,30 @@ func TestMountSymlink(t *testing.T) {
}
}
+// Check that --net-raw disables the CAP_NET_RAW capability.
+func TestNetRaw(t *testing.T) {
+ capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
+ app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for _, enableRaw := range []bool{true, false} {
+ conf := testutil.TestConfig()
+ conf.EnableRaw = enableRaw
+
+ test := "--enabled"
+ if !enableRaw {
+ test = "--disabled"
+ }
+
+ spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ }
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 7f735c254..913d781c6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -19,10 +19,12 @@ package main
import (
"context"
"fmt"
+ "io/ioutil"
"log"
"net"
"os"
"os/exec"
+ "regexp"
"strconv"
sys "syscall"
"time"
@@ -35,6 +37,7 @@ import (
func main() {
subcommands.Register(subcommands.HelpCommand(), "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ subcommands.Register(new(capability), "")
subcommands.Register(new(fdReceiver), "")
subcommands.Register(new(fdSender), "")
subcommands.Register(new(forkBomb), "")
@@ -287,3 +290,65 @@ func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
}
return subcommands.ExitSuccess
}
+
+type capability struct {
+ enabled uint64
+ disabled uint64
+}
+
+// Name implements subcommands.Command.
+func (*capability) Name() string {
+ return "capability"
+}
+
+// Synopsis implements subcommands.Command.
+func (*capability) Synopsis() string {
+ return "checks if effective capabilities are set/unset"
+}
+
+// Usage implements subcommands.Command.
+func (*capability) Usage() string {
+ return "capability [--enabled=number] [--disabled=number]"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *capability) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&c.enabled, "enabled", 0, "")
+ f.Uint64Var(&c.disabled, "disabled", 0, "")
+}
+
+// Execute implements subcommands.Command.
+func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.enabled == 0 && c.disabled == 0 {
+ fmt.Println("One of the flags must be set")
+ return subcommands.ExitUsageError
+ }
+
+ status, err := ioutil.ReadFile("/proc/self/status")
+ if err != nil {
+ fmt.Printf("Error reading %q: %v\n", "proc/self/status", err)
+ return subcommands.ExitFailure
+ }
+ re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n")
+ matches := re.FindStringSubmatch(string(status))
+ if matches == nil || len(matches) != 2 {
+ fmt.Printf("Effective capabilities not found in\n%s\n", status)
+ return subcommands.ExitFailure
+ }
+ caps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err)
+ return subcommands.ExitFailure
+ }
+
+ if c.enabled != 0 && (caps&c.enabled) != c.enabled {
+ fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps)
+ return subcommands.ExitFailure
+ }
+ if c.disabled != 0 && (caps&c.disabled) != 0 {
+ fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps)
+ return subcommands.ExitFailure
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
index e37ec0ffd..57f6ae8de 100644
--- a/runsc/dockerutil/dockerutil.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -282,7 +282,14 @@ func (d *Docker) Logs() (string, error) {
// Exec calls 'docker exec' with the arguments provided.
func (d *Docker) Exec(args ...string) (string, error) {
- a := []string{"exec", d.Name}
+ return d.ExecWithFlags(nil, args...)
+}
+
+// ExecWithFlags calls 'docker exec <flags> name <args>'.
+func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) {
+ a := []string{"exec"}
+ a = append(a, flags...)
+ a = append(a, d.Name)
a = append(a, args...)
return do(a...)
}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index a570f1a41..29a82138e 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -136,6 +136,10 @@ func (a *attachPoint) Attach() (p9.File, error) {
a.attachedMu.Lock()
defer a.attachedMu.Unlock()
+ if a.attached {
+ return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
+ }
+
// Hold the file descriptor we are converting into a p9.File.
var f *fd.FD
@@ -170,12 +174,6 @@ func (a *attachPoint) Attach() (p9.File, error) {
}
}
- // Close the connection if already attached.
- if a.attached {
- f.Close()
- return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
- }
-
// Return a localFile object to the caller with the UDS FD included.
rv, err := newLocalFile(a, f, a.prefix, stat)
if err != nil {
@@ -330,7 +328,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error)
return file, nil
}
-func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
+func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
var ft fileType
switch stat.Mode & syscall.S_IFMT {
case syscall.S_IFREG:
@@ -340,6 +338,9 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
case syscall.S_IFLNK:
ft = symlink
case syscall.S_IFSOCK:
+ if !permitSocket {
+ return unknown, syscall.EPERM
+ }
ft = socket
default:
return unknown, syscall.EPERM
@@ -348,7 +349,7 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
}
func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
- ft, err := getSupportedFileType(stat)
+ ft, err := getSupportedFileType(stat, a.conf.HostUDS)
if err != nil {
return nil, err
}
@@ -1065,7 +1066,7 @@ func (l *localFile) Flush() error {
func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) {
// Check to see if the CLI option has been set to allow the UDS mount.
if !l.attachPoint.conf.HostUDS {
- return nil, errors.New("host UDS support is disabled")
+ return nil, syscall.ECONNREFUSED
}
return fd.DialUnix(l.hostPath)
}
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index cbbe71019..05af7e397 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -665,7 +665,7 @@ func TestAttachInvalidType(t *testing.T) {
}
f, err := a.Attach()
if f != nil || err == nil {
- t.Fatalf("Attach should have failed, got (%v, nil)", f)
+ t.Fatalf("Attach should have failed, got (%v, %v)", f, err)
}
})
}
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index fbfb8e2f8..fa58313a0 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -13,6 +13,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/kernel/auth",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index cb9e58dfb..591abe458 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -31,6 +31,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -241,6 +242,15 @@ func AllCapabilities() *specs.LinuxCapabilities {
}
}
+// AllCapabilitiesUint64 returns a bitmask containing all capabilities set.
+func AllCapabilitiesUint64() uint64 {
+ var rv uint64
+ for _, cap := range capFromName {
+ rv |= bits.MaskOf64(int(cap))
+ }
+ return rv
+}
+
var capFromName = map[string]linux.Capability{
"CAP_CHOWN": linux.CAP_CHOWN,
"CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE,
diff --git a/scripts/build.sh b/scripts/build.sh
index b3a6e4e7a..0b3d1b316 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -23,7 +23,7 @@ sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
runsc=$(build -c opt //runsc)
# Build packages.
-pkg=$(build -c opt --host_force_python=py2 //runsc:runsc-debian)
+pkg=$(build -c opt //runsc:runsc-debian)
# Build a repository, if the key is available.
if [[ -v KOKORO_REPO_KEY ]]; then
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index 99442cffb..4fe03a220 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -19,7 +19,9 @@ go_test(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//runsc/dockerutil",
+ "//runsc/specutils",
"//runsc/testutil",
],
)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index 7238c2afe..88d26e865 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -30,14 +30,17 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
)
+// Test that exec uses the exact same capability set as the container.
func TestExecCapabilities(t *testing.T) {
if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := dockerutil.MakeDocker("exec-test")
+ d := dockerutil.MakeDocker("exec-capabilities-test")
// Start the container.
if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
@@ -52,27 +55,59 @@ func TestExecCapabilities(t *testing.T) {
if len(matches) != 2 {
t.Fatalf("There should be a match for the whole line and the capability bitmask")
}
- capString := matches[1]
- t.Log("Root capabilities:", capString)
+ want := fmt.Sprintf("CapEff:\t%s\n", matches[1])
+ t.Log("Root capabilities:", want)
- // CAP_NET_RAW was in the capability set for the container, but was
- // removed. However, `exec` does not remove it. Verify that it's not
- // set in the container, then re-add it for comparison.
- caps, err := strconv.ParseUint(capString, 16, 64)
+ // Now check that exec'd process capabilities match the root.
+ got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
if err != nil {
- t.Fatalf("failed to convert capabilities %q: %v", capString, err)
+ t.Fatalf("docker exec failed: %v", err)
}
- if caps&(1<<uint64(linux.CAP_NET_RAW)) != 0 {
- t.Fatalf("CAP_NET_RAW should be filtered, but is set in the container: %x", caps)
+ t.Logf("CapEff: %v", got)
+ if got != want {
+ t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
}
- caps |= 1 << uint64(linux.CAP_NET_RAW)
- want := fmt.Sprintf("CapEff:\t%016x\n", caps)
+}
- // Now check that exec'd process capabilities match the root.
- got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
+// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW
+// which is removed from the container when --net-raw=false.
+func TestExecPrivileged(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-privileged-test")
+
+ // Start the container with all capabilities dropped.
+ if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Check that all capabilities where dropped from container.
+ matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ containerCaps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ t.Fatalf("failed to convert capabilities %q: %v", matches[1], err)
+ }
+ t.Logf("Container capabilities: %#x", containerCaps)
+ if containerCaps != 0 {
+ t.Fatalf("Container should have no capabilities: %x", containerCaps)
+ }
+
+ // Check that 'exec --privileged' adds all capabilities, except
+ // for CAP_NET_RAW.
+ got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
+ t.Logf("Exec CapEff: %v", got)
+ want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW)))
if got != want {
t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
}
@@ -184,7 +219,7 @@ func TestExecEnvHasHome(t *testing.T) {
if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := dockerutil.MakeDocker("exec-env-test")
+ d := dockerutil.MakeDocker("exec-env-home-test")
// We will check that HOME is set for root user, and also for a new
// non-root user we will create.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 28b23ce58..d5a2b7725 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2464,6 +2464,63 @@ cc_binary(
)
cc_binary(
+ name = "socket_bind_to_device_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_sequence_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_sequence.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_distribution_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_distribution.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "socket_ip_udp_loopback_non_blocking_test",
testonly = 1,
srcs = [
@@ -2740,6 +2797,23 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "socket_bind_to_device_util",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_util.cc",
+ ],
+ hdrs = [
+ "socket_bind_to_device_util.h",
+ ],
+ deps = [
+ "//test/util:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
cc_binary(
name = "socket_stream_local_test",
testonly = 1,
@@ -3163,8 +3237,6 @@ cc_binary(
testonly = 1,
srcs = ["timers.cc"],
linkstatic = 1,
- # FIXME(b/136599201)
- tags = ["flaky"],
deps = [
"//test/util:cleanup",
"//test/util:logging",
@@ -3253,6 +3325,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:uid_util",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc
index 52a69d230..b55cddc52 100644
--- a/test/syscalls/linux/clock_nanosleep.cc
+++ b/test/syscalls/linux/clock_nanosleep.cc
@@ -43,7 +43,7 @@ int sys_clock_nanosleep(clockid_t clkid, int flags,
PosixErrorOr<absl::Time> GetTime(clockid_t clk) {
struct timespec ts = {};
- int rc = clock_gettime(clk, &ts);
+ const int rc = clock_gettime(clk, &ts);
MaybeSave();
if (rc < 0) {
return PosixError(errno, "clock_gettime");
@@ -67,31 +67,32 @@ TEST_P(WallClockNanosleepTest, InvalidValues) {
}
TEST_P(WallClockNanosleepTest, SleepOneSecond) {
- absl::Duration const duration = absl::Seconds(1);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(1);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_GE(after - before, duration);
+ EXPECT_GE(after - before, kSleepDuration);
}
TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
- absl::Duration const duration = absl::Seconds(60);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(60);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Install no-op signal handler for SIGALRM.
struct sigaction sa = {};
sigfillset(&sa.sa_mask);
sa.sa_handler = +[](int signo) {};
- auto const cleanup_sa =
+ const auto cleanup_sa =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
// Measure time since setting the alarm, since the alarm will interrupt the
// sleep and hence determine how long we sleep.
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
// Set an alarm to go off while sleeping.
struct itimerval timer = {};
@@ -99,26 +100,51 @@ TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
timer.it_value.tv_usec = 0;
timer.it_interval.tv_sec = 1;
timer.it_interval.tv_usec = 0;
- auto const cleanup =
+ const auto cleanup =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer));
- EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &dur, &dur),
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration),
SyscallFailsWithErrno(EINTR));
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Duration const remaining = absl::DurationFromTimespec(dur);
- EXPECT_GE(after - before + remaining, duration);
+ // Remaining time updated.
+ const absl::Duration remaining = absl::DurationFromTimespec(duration);
+ EXPECT_GE(after - before + remaining, kSleepDuration);
+}
+
+// Remaining time is *not* updated if nanosleep completes uninterrupted.
+TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Milliseconds(10);
+ const struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ while (true) {
+ constexpr int kRemainingMagic = 42;
+ struct timespec remaining;
+ remaining.tv_sec = kRemainingMagic;
+ remaining.tv_nsec = kRemainingMagic;
+
+ int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining);
+ if (ret == EINTR) {
+ // Retry from beginning. We want a single uninterrupted call.
+ continue;
+ }
+
+ EXPECT_THAT(ret, SyscallSucceeds());
+ EXPECT_EQ(remaining.tv_sec, kRemainingMagic);
+ EXPECT_EQ(remaining.tv_nsec, kRemainingMagic);
+ break;
+ }
}
TEST_P(WallClockNanosleepTest, SleepUntil) {
- absl::Time const now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Time const until = now + absl::Seconds(2);
- struct timespec ts = absl::ToTimespec(until);
+ const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time until = now + absl::Seconds(2);
+ const struct timespec ts = absl::ToTimespec(until);
EXPECT_THAT(
RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr),
SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
EXPECT_GE(after, until);
}
@@ -127,8 +153,8 @@ INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest,
::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC));
TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
- absl::Duration const kDuration = absl::Seconds(5);
- struct timespec dur = absl::ToTimespec(kDuration);
+ const absl::Duration kSleepDuration = absl::Seconds(5);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Ensure that CLOCK_PROCESS_CPUTIME_ID advances.
std::atomic<bool> done(false);
@@ -136,16 +162,16 @@ TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
while (!done.load()) {
}
});
- auto const cleanup_done = Cleanup([&] { done.store(true); });
+ const auto cleanup_done = Cleanup([&] { done.store(true); });
- absl::Time const before =
+ const absl::Time before =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_THAT(
- RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after =
+ EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0,
+ &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_GE(after - before, kDuration);
+ EXPECT_GE(after - before, kSleepDuration);
}
} // namespace
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 91b55015c..68af882bb 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -401,12 +401,17 @@ TEST(ElfTest, DataSegment) {
})));
}
-// Additonal pages beyond filesz are always RW.
+// Additonal pages beyond filesz honor (only) execute protections.
//
-// N.B. Linux uses set_brk -> vm_brk to additional pages beyond filesz (even
-// though start_brk itself will always be beyond memsz). As a result, the
-// segment permissions don't apply; the mapping is always RW.
+// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the
+// entire heap executable"). Previously, extra pages were always RW.
TEST(ElfTest, ExtraMemPages) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11));
+ }
+
ElfBinary<64> elf = StandardElf();
// Create a standard ELF, but extend to 1.5 pages. The second page will be the
@@ -415,7 +420,7 @@ TEST(ElfTest, ExtraMemPages) {
decltype(elf)::ElfPhdr phdr = {};
phdr.p_type = PT_LOAD;
- // RWX segment. The extra anon page will be RW anyways.
+ // RWX segment. The extra anon page will also be RWX.
//
// N.B. Linux uses clear_user to clear the end of the file-mapped page, which
// respects the mapping protections. Thus if we map this RO with memsz >
@@ -454,7 +459,7 @@ TEST(ElfTest, ExtraMemPages) {
{0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0,
file.path().c_str()},
// extra page from anon.
- {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""},
+ {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""},
})));
}
@@ -469,7 +474,7 @@ TEST(ElfTest, AnonOnlySegment) {
phdr.p_offset = 0;
phdr.p_vaddr = 0x41000;
phdr.p_filesz = 0;
- phdr.p_memsz = kPageSize - 0xe8;
+ phdr.p_memsz = kPageSize;
elf.phdrs.push_back(phdr);
elf.UpdateOffsets();
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 7a3379b9e..37b4e6575 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -83,9 +83,15 @@ void SendUDPMessage(int sock) {
// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
TEST(BasicCookedPacketTest, WrongType) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
FileDescriptor sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
@@ -118,18 +124,27 @@ class CookedPacketTest : public ::testing::TestWithParam<int> {
};
void CookedPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
SyscallSucceeds());
}
void CookedPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int CookedPacketTest::GetLoopbackIndex() {
@@ -142,9 +157,6 @@ int CookedPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(CookedPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -201,9 +213,6 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 9e96460ee..6491453b6 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -97,9 +97,15 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
};
void RawPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
if (!IsRunningOnGvisor()) {
FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
@@ -119,10 +125,13 @@ void RawPacketTest::SetUp() {
}
void RawPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int RawPacketTest::GetLoopbackIndex() {
@@ -135,9 +144,6 @@ int RawPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(RawPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -208,9 +214,6 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index bf32efe1e..99a0df235 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1527,27 +1527,36 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ int sync_setsid[2];
+ int sync_exit[2];
+ ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds());
+ ASSERT_THAT(pipe(sync_exit), SyscallSucceeds());
+
// Create a new process and put it in a new session.
pid_t child = fork();
if (!child) {
TEST_PCHECK(setsid() >= 0);
// Tell the parent we're in a new session.
- TEST_PCHECK(!raise(SIGSTOP));
- TEST_PCHECK(!pause());
- _exit(1);
+ char c = 'c';
+ TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1);
+ TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1);
+ _exit(0);
}
// Wait for the child to tell us it's in a new session.
- int wstatus;
- EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED),
- SyscallSucceedsWithValue(child));
- EXPECT_TRUE(WSTOPSIG(wstatus));
+ char c = 'c';
+ ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1));
// Child is in a new session, so we can't make it the foregroup process group.
EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+ EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1));
+
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(wstatus));
+ EXPECT_EQ(WEXITSTATUS(wstatus), 0);
}
// Verify that we don't hang when creating a new session from an orphaned
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index a070817eb..0a27506aa 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test {
};
void RawHDRINCL::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
SyscallSucceeds());
@@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() {
}
void RawHDRINCL::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
struct iphdr RawHDRINCL::LoopbackHeader() {
@@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup
// creates the first one, so we only have to create one more here.
TEST_F(RawHDRINCL, MultipleCreation) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int s2;
ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
@@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) {
// Test that shutting down an unconnected socket fails.
TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
}
// Test that listen() fails.
TEST_F(RawHDRINCL, FailListen) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
}
// Test that accept() fails.
TEST_F(RawHDRINCL, FailAccept) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr saddr;
socklen_t addrlen;
ASSERT_THAT(accept(socket_, &saddr, &addrlen),
@@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) {
// Test that the socket is writable immediately.
TEST_F(RawHDRINCL, PollWritableImmediately) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct pollfd pfd = {};
pfd.fd = socket_;
pfd.events = POLLOUT;
@@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) {
// Test that the socket isn't readable.
TEST_F(RawHDRINCL, NotReadable) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
// Try to receive data with MSG_DONTWAIT, which returns immediately if there's
// nothing to be read.
char buf[117];
@@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) {
// Test that we can connect() to a valid IP (loopback).
TEST_F(RawHDRINCL, ConnectToLoopback) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
}
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct iphdr hdr = LoopbackHeader();
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
SyscallSucceedsWithValue(sizeof(hdr)));
@@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
// HDRINCL implies write-only. Verify that we can't read a packet sent to
// loopback.
TEST_F(RawHDRINCL, NotReadableAfterWrite) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) {
}
TEST_F(RawHDRINCL, WriteTooSmall) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) {
// Bind to localhost.
TEST_F(RawHDRINCL, BindToLocalhost) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(
bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
SyscallSucceeds());
@@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) {
// Bind to a different address.
TEST_F(RawHDRINCL, BindToInvalid) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr_in bind_addr = {};
bind_addr.sin_family = AF_INET;
bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
@@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) {
// Send and receive a packet.
TEST_F(RawHDRINCL, SendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) {
// Send and receive a packet with nonzero IP ID.
TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
// Send and receive a packet where the sendto address is not the same as the
// provided destination.
TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 971592d7d..8bcaba6f1 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test {
};
void RawSocketICMPTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
@@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() {
}
void RawSocketICMPTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We'll only read an echo in this case, as the kernel won't respond to the
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
index 352037c88..cde2f07c9 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam<int> {
};
void RawSocketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
@@ -79,9 +83,10 @@ void RawSocketTest::SetUp() {
}
void RawSocketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We should be able to create multiple raw sockets for the same protocol.
diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc
new file mode 100644
index 000000000..d20821cac
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device.cc
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Test fixture for SO_BINDTODEVICE tests.
+class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+
+ interface_name_ = "eth1";
+ auto interface_names = GetInterfaceNames();
+ if (interface_names.find(interface_name_) == interface_names.end()) {
+ // Need a tunnel.
+ tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New());
+ interface_name_ = tunnel_->GetName();
+ ASSERT_FALSE(interface_name_.empty());
+ }
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+ }
+
+ string interface_name() const { return interface_name_; }
+
+ int socket_fd() const { return socket_->get(); }
+
+ private:
+ std::unique_ptr<Tunnel> tunnel_;
+ string interface_name_;
+ std::unique_ptr<FileDescriptor> socket_;
+};
+
+constexpr char kIllegalIfnameChar = '/';
+
+// Tests getsockopt of the default value.
+TEST_P(BindToDeviceTest, GetsockoptDefault) {
+ char name_buffer[IFNAMSIZ * 2];
+ char original_name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Read the default SO_BINDTODEVICE.
+ memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(name_buffer_size, 0);
+ EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)),
+ 0);
+ }
+}
+
+// Tests setsockopt of invalid device name.
+TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Set an invalid device name.
+ memset(name_buffer, kIllegalIfnameChar, 5);
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Tests setsockopt of a buffer with a valid device name but not
+// null-terminated, with different sizes of buffer.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Intentionally overwrite the null at the end.
+ memset(name_buffer + interface_name().size(), kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size());
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is exactly right.
+ if (name_buffer_size == interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests setsockopt of a buffer with a valid device name and null-terminated,
+// with different sizes of buffer.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Don't overwrite the null at the end.
+ memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size() - 1);
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is at least the right size.
+ if (name_buffer_size >= interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests that setsockopt of an invalid device name doesn't unset the previous
+// valid setsockopt.
+TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Write unsuccessfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallFailsWithErrno(ENODEV));
+
+ // Read it back successfully, it's unchanged.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+}
+
+// Tests that setsockopt of zero-length string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClear) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ name_buffer_size = 0;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests that setsockopt of empty string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer[0] = 0;
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests getsockopt with different buffer sizes.
+TEST_P(BindToDeviceTest, GetsockoptDevice) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back at various buffer sizes.
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice
+ // for this interface name.
+ if (name_buffer_size >= IFNAMSIZ) {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+ } else {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(name_buffer_size, i);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
new file mode 100644
index 000000000..4d2400328
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -0,0 +1,381 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <atomic>
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+struct EndpointConfig {
+ std::string bind_to_device;
+ double expected_ratio;
+};
+
+struct DistributionTestCase {
+ std::string name;
+ std::vector<EndpointConfig> endpoints;
+};
+
+struct ListenerConnector {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+// Test fixture for SO_BINDTODEVICE tests the distribution of packets received
+// with varying SO_BINDTODEVICE settings.
+class BindToDeviceDistributionTest
+ : public ::testing::TestWithParam<
+ ::testing::tuple<ListenerConnector, DistributionTestCase>> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s, listener=%s, connector=%s\n",
+ ::testing::get<1>(GetParam()).name.c_str(),
+ ::testing::get<0>(GetParam()).listener.description.c_str(),
+ ::testing::get<0>(GetParam()).connector.description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ }
+};
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+// Binds sockets to different devices and then creates many TCP connections.
+// Checks that the distribution of connections received on the sockets matches
+// the expectation.
+TEST_P(BindToDeviceDistributionTest, Tcp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> accept_counts(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> listen_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ listen_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, &connects_received, i,
+ kConnectAttempts]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ // Another thread has shutdown our read side causing the accept to
+ // fail.
+ ASSERT_GE(connects_received, kConnectAttempts)
+ << "errno = " << fd.error();
+ return;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& listen_thread : listen_threads) {
+ listen_thread->Join();
+ }
+ // Check that connections are distributed correctly among listening sockets.
+ for (int i = 0; i < accept_counts.size(); i++) {
+ EXPECT_THAT(
+ accept_counts[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Binds sockets to different devices and then sends many UDP packets. Checks
+// that the distribution of packets received on the sockets matches the
+// expectation.
+TEST_P(BindToDeviceDistributionTest, Udp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening socket.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> packets_per_socket(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ receiver_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, &packets_received, i]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shutdown our read side causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& receiver_thread : receiver_threads) {
+ receiver_thread->Join();
+ }
+ // Check that packets are distributed correctly among listening sockets.
+ for (int i = 0; i < packets_per_socket.size(); i++) {
+ EXPECT_THAT(
+ packets_per_socket[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+std::vector<DistributionTestCase> GetDistributionTestCases() {
+ return std::vector<DistributionTestCase>{
+ {"Even distribution among sockets not bound to device",
+ {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
+ {"Sockets bound to other interfaces get no packets",
+ {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
+ {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
+ {"Even distribution among sockets bound to device",
+ {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BindToDeviceTest, BindToDeviceDistributionTest,
+ ::testing::Combine(::testing::Values(
+ // Listeners bound to IPv4 addresses refuse
+ // connections using IPv6 addresses.
+ ListenerConnector{V4Any(), V4Loopback()},
+ ListenerConnector{V4Loopback(), V4MappedLoopback()}),
+ ::testing::ValuesIn(GetDistributionTestCases())));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
new file mode 100644
index 000000000..a7365d139
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -0,0 +1,316 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket
+// binding.
+class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ socket_factory_ = GetParam();
+
+ interface_names_ = GetInterfaceNames();
+ }
+
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+ return socket_factory_.Create();
+ }
+
+ // Gets a device by device_id. If the device_id has been seen before, returns
+ // the previously returned device. If not, finds or creates a new device.
+ // Returns an empty string on failure.
+ void GetDevice(int device_id, string *device_name) {
+ auto device = devices_.find(device_id);
+ if (device != devices_.end()) {
+ *device_name = device->second;
+ return;
+ }
+
+ // Need to pick a new device. Try ethernet first.
+ *device_name = absl::StrCat("eth", next_unused_eth_);
+ if (interface_names_.find(*device_name) != interface_names_.end()) {
+ devices_[device_id] = *device_name;
+ next_unused_eth_++;
+ return;
+ }
+
+ // Need to make a new tunnel device. gVisor tests should have enough
+ // ethernet devices to never reach here.
+ ASSERT_FALSE(IsRunningOnGvisor());
+ // Need a tunnel.
+ tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()));
+ devices_[device_id] = tunnels_.back()->GetName();
+ *device_name = devices_[device_id];
+ }
+
+ // Release the socket
+ void ReleaseSocket(int socket_id) {
+ // Close the socket that was made in a previous action. The socket_id
+ // indicates which socket to close based on index into the list of actions.
+ sockets_to_close_.erase(socket_id);
+ }
+
+ // Bind a socket with the reuse option and bind_to_device options. Checks
+ // that all steps succeed and that the bind command's error matches want.
+ // Sets the socket_id to uniquely identify the socket bound if it is not
+ // nullptr.
+ void BindSocket(bool reuse, int device_id = 0, int want = 0,
+ int *socket_id = nullptr) {
+ next_socket_id_++;
+ sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket_fd = sockets_to_close_[next_socket_id_]->get();
+ if (socket_id != nullptr) {
+ *socket_id = next_socket_id_;
+ }
+
+ // If reuse is indicated, do that.
+ if (reuse) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // If the device is non-zero, bind to that device.
+ if (device_id != 0) {
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ char get_device[100];
+ socklen_t get_device_size = 100;
+ EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device,
+ &get_device_size),
+ SyscallSucceedsWithValue(0));
+ }
+
+ struct sockaddr_in addr = {};
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr.sin_port = port_;
+ if (want == 0) {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ } else {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(want));
+ }
+
+ if (port_ == 0) {
+ // We don't yet know what port we'll be using so we need to fetch it and
+ // remember it for future commands.
+ socklen_t addr_size = sizeof(addr);
+ ASSERT_THAT(
+ getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr),
+ &addr_size),
+ SyscallSucceeds());
+ port_ = addr.sin_port;
+ }
+ }
+
+ private:
+ SocketKind socket_factory_;
+ // devices maps from the device id in the test case to the name of the device.
+ std::unordered_map<int, string> devices_;
+ // These are the tunnels that were created for the test and will be destroyed
+ // by the destructor.
+ vector<std::unique_ptr<Tunnel>> tunnels_;
+ // A list of all interface names before the test started.
+ std::unordered_set<string> interface_names_;
+ // The next ethernet device to use when requested a device.
+ int next_unused_eth_ = 1;
+ // The port for all tests. Originally 0 (any) and later set to the port that
+ // all further commands will use.
+ in_port_t port_ = 0;
+ // sockets_to_close_ is a map from action index to the socket that was
+ // created.
+ std::unordered_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
+ sockets_to_close_;
+ int next_socket_id_ = 0;
+};
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 2));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ // Release the bind to device 0 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc
new file mode 100644
index 000000000..f4ee775bd
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) {
+ int fd;
+ RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR));
+
+ // Using `new` to access a non-public constructor.
+ auto new_tunnel = absl::WrapUnique(new Tunnel(fd));
+
+ ifreq ifr = {};
+ ifr.ifr_flags = IFF_TUN;
+ strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr));
+ new_tunnel->name_ = ifr.ifr_name;
+ return new_tunnel;
+}
+
+std::unordered_set<string> GetInterfaceNames() {
+ struct if_nameindex* interfaces = if_nameindex();
+ std::unordered_set<string> names;
+ if (interfaces == nullptr) {
+ return names;
+ }
+ for (auto interface = interfaces;
+ interface->if_index != 0 || interface->if_name != nullptr; interface++) {
+ names.insert(interface->if_name);
+ }
+ if_freenameindex(interfaces);
+ return names;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h
new file mode 100644
index 000000000..f941ccc86
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.h
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+class Tunnel {
+ public:
+ static PosixErrorOr<std::unique_ptr<Tunnel>> New(
+ std::string tunnel_name = "");
+ const std::string& GetName() const { return name_; }
+
+ ~Tunnel() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+ }
+
+ private:
+ Tunnel(int fd) : fd_(fd) {}
+ int fd_ = -1;
+ std::string name_;
+};
+
+std::unordered_set<std::string> GetInterfaceNames();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index d48453a93..6218fbce1 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -25,6 +25,7 @@
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/uid_util.h"
ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
@@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) {
// here; see the setgroups test below.
}
-// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns
-// true. Otherwise IsRoot logs an explanatory message and returns false.
-PosixErrorOr<bool> IsRoot() {
- uid_t ruid, euid, suid;
- int rc = getresuid(&ruid, &euid, &suid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresuid");
- }
- if (ruid != 0 || euid != 0 || suid != 0) {
- return false;
- }
- gid_t rgid, egid, sgid;
- rc = getresgid(&rgid, &egid, &sgid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresgid");
- }
- if (rgid != 0 || egid != 0 || sgid != 0) {
- return false;
- }
- return true;
-}
-
// Checks that the calling process' real/effective/saved user IDs are
// ruid/euid/suid respectively.
PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) {
diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc
index 0a5d91017..d8824b171 100644
--- a/test/syscalls/linux/uname.cc
+++ b/test/syscalls/linux/uname.cc
@@ -41,6 +41,19 @@ TEST(UnameTest, Sanity) {
TEST(UnameTest, SetNames) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ char hostname[65];
+ ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "012");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
constexpr char kHostname[] = "wubbalubba";
ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
@@ -54,7 +67,6 @@ TEST(UnameTest, SetNames) {
EXPECT_EQ(absl::string_view(buf.domainname), kDomainname);
// These should just be glibc wrappers that also call uname(2).
- char hostname[65];
EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
EXPECT_EQ(absl::string_view(hostname), kHostname);
diff --git a/test/util/BUILD b/test/util/BUILD
index 25ed9c944..5d2a9cc2c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -324,3 +324,14 @@ cc_library(
":test_util",
],
)
+
+cc_library(
+ name = "uid_util",
+ testonly = 1,
+ srcs = ["uid_util.cc"],
+ hdrs = ["uid_util.h"],
+ deps = [
+ ":posix_error",
+ ":save_util",
+ ],
+)
diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc
index 75b24da37..34d636ba9 100644
--- a/test/util/proc_util.cc
+++ b/test/util/proc_util.cc
@@ -88,7 +88,7 @@ PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
std::vector<ProcMapsEntry> entries;
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
for (const auto& l : lines) {
- std::cout << "line: " << l;
+ std::cout << "line: " << l << std::endl;
ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l));
entries.push_back(entry);
}
diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc
new file mode 100644
index 000000000..b131b4b99
--- /dev/null
+++ b/test/util/uid_util.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> IsRoot() {
+ uid_t ruid, euid, suid;
+ int rc = getresuid(&ruid, &euid, &suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != 0 || euid != 0 || suid != 0) {
+ return false;
+ }
+ gid_t rgid, egid, sgid;
+ rc = getresgid(&rgid, &egid, &sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != 0 || egid != 0 || sgid != 0) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/uid_util.h b/test/util/uid_util.h
new file mode 100644
index 000000000..2cd387fb0
--- /dev/null
+++ b/test/util/uid_util.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns true if the caller's real/effective/saved user/group IDs are all 0.
+PosixErrorOr<bool> IsRoot();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_