From 00f8663887cbf9057d93e8848eb9538cf1c0cff4 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 3 Jun 2019 21:24:56 -0700 Subject: gvisor/fs: return a proper error from FileWriter.Write in case of a short-write The io.Writer contract requires that Write writes all available bytes and does not return short writes. This causes errors with io.Copy, since our own Write interface does not have this same contract. PiperOrigin-RevId: 251368730 --- pkg/sentry/fs/file.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 8c1307235..f64954457 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -545,12 +545,28 @@ type lockedWriter struct { // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), w.File.offset) - return int(n), err + return w.WriteAt(buf, w.File.offset) } // WriteAt implements io.Writer.WriteAt. func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { - n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), offset) - return int(n), err + var ( + written int + err error + ) + // The io.Writer contract requires that Write writes all available + // bytes and does not return short writes. This causes errors with + // io.Copy, since our own Write interface does not have this same + // contract. Enforce that here. + for written < len(buf) { + var n int64 + n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) + if n > 0 { + written += int(n) + } + if err != nil { + break + } + } + return written, err } -- cgit v1.2.3 From 90a116890fcea9fd39911bae854e4e67608a141d Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 3 Jun 2019 21:47:09 -0700 Subject: gvisor/sock/unix: pass creds when a message is sent between unconnected sockets and don't report a sender address if it doesn't have one PiperOrigin-RevId: 251371284 --- pkg/sentry/fs/gofer/socket.go | 5 +++++ pkg/sentry/socket/control/control.go | 12 ++++++++++-- pkg/sentry/socket/unix/transport/unix.go | 4 ++++ pkg/sentry/socket/unix/unix.go | 6 +++++- test/syscalls/linux/accept_bind.cc | 14 +------------- test/syscalls/linux/socket_unix_unbound_dgram.cc | 24 ++++++++++++++++++++++++ 6 files changed, 49 insertions(+), 16 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index cbd5b9a84..7376fd76f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -139,3 +139,8 @@ func (e *endpoint) UnidirectionalConnect() (transport.ConnectedEndpoint, *syserr func (e *endpoint) Release() { e.inode.DecRef() } + +// Passcred implements transport.BoundEndpoint.Passcred. +func (e *endpoint) Passcred() bool { + return false +} diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index c0238691d..434d7ca2e 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -406,12 +406,20 @@ func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { return nil } if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) { - tcred := t.Credentials() - return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID} + return MakeCreds(t) } return nil } +// MakeCreds creates default SCMCredentials. +func MakeCreds(t *kernel.Task) SCMCredentials { + if t == nil { + return nil + } + tcred := t.Credentials() + return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID} +} + // New creates default control messages if needed. func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages { return transport.ControlMessages{ diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index b734b4c20..37d82bb6b 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -237,6 +237,10 @@ type BoundEndpoint interface { // endpoint. UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) + // Passcred returns whether or not the SO_PASSCRED socket option is + // enabled on this end. + Passcred() bool + // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1414be0c6..388cc0d8b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -385,6 +385,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } defer ep.Release() w.To = ep + + if ep.Passcred() && w.Control.Credentials == nil { + w.Control.Credentials = control.MakeCreds(t) + } } n, err := src.CopyInTo(t, &w) @@ -516,7 +520,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { var from interface{} var fromLen uint32 - if r.From != nil { + if r.From != nil && len([]byte(r.From.Addr)) != 0 { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index 56377feab..1122ea240 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -448,19 +448,7 @@ TEST_P(AllSocketPairTest, UnboundSenderAddr) { RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, reinterpret_cast(&addr), &addr_len), SyscallSucceedsWithValue(sizeof(i))); - if (!IsRunningOnGvisor()) { - // Linux returns a zero length for addresses from recvfrom(2) and - // recvmsg(2). This differs from the behavior of getpeername(2) and - // getsockname(2). For simplicity, we use the getpeername(2) and - // getsockname(2) behavior for recvfrom(2) and recvmsg(2). - EXPECT_EQ(addr_len, 0); - return; - } - EXPECT_EQ(addr_len, 2); - EXPECT_EQ( - memcmp(&addr, sockets->second_addr(), - std::min((size_t)addr_len, (size_t)sockets->second_addr_len())), - 0); + EXPECT_EQ(addr_len, 0); } TEST_P(AllSocketPairTest, BoundSenderAddr) { diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc index 2ddc5c11f..52aef891f 100644 --- a/test/syscalls/linux/socket_unix_unbound_dgram.cc +++ b/test/syscalls/linux/socket_unix_unbound_dgram.cc @@ -13,7 +13,9 @@ // limitations under the License. #include +#include #include + #include "gtest/gtest.h" #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" @@ -142,6 +144,28 @@ TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnect) { SyscallSucceedsWithValue(sizeof(data))); } +TEST_P(UnboundDgramUnixSocketPairTest, SendtoWithoutConnectPassCreds) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + SetSoPassCred(sockets->first_fd()); + char data = 'a'; + ASSERT_THAT( + RetryEINTR(sendto)(sockets->second_fd(), &data, sizeof(data), 0, + sockets->first_addr(), sockets->first_addr_size()), + SyscallSucceedsWithValue(sizeof(data))); + ucred creds; + creds.pid = -1; + char buf[sizeof(data) + 1]; + ASSERT_NO_FATAL_FAILURE( + RecvCreds(sockets->first_fd(), &creds, buf, sizeof(buf), sizeof(data))); + EXPECT_EQ(0, memcmp(&data, buf, sizeof(data))); + EXPECT_THAT(getpid(), SyscallSucceedsWithValue(creds.pid)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnboundDgramUnixSocketPairTest, ::testing::ValuesIn(VecCat( -- cgit v1.2.3 From 7398f013f043cfe43b5fc615bd24b641df17e6bc Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 4 Jun 2019 15:39:24 -0700 Subject: Drop one dirent reference after referenced by file When pipe is created, a dirent of pipe will be created and its initial reference is set as 0. Cause all dirent will only be destroyed when the reference decreased to -1, so there is already a 'initial reference' of dirent after it created. For destroying dirent after all reference released, the correct way is to drop the 'initial reference' once someone hold a reference to the dirent, such as fs.NewFile, otherwise the reference of dirent will stay 0 all the time, and will cause memory leak of dirent. Except pipe, timerfd/eventfd/epoll has the same problem Here is a simple case to create memory leak of dirent for pipe/timerfd/eventfd/epoll in C langange, after run the case, pprof the runsc process, you will find lots dirents of pipe/timerfd/eventfd/epoll not freed: int main(int argc, char *argv[]) { int i; int n; int pipefd[2]; if (argc != 3) { printf("Usage: %s epoll|timerfd|eventfd|pipe \n", argv[0]); } n = strtol(argv[2], NULL, 10); if (strcmp(argv[1], "epoll") == 0) { for (i = 0; i < n; ++i) close(epoll_create(1)); } else if (strcmp(argv[1], "timerfd") == 0) { for (i = 0; i < n; ++i) close(timerfd_create(CLOCK_REALTIME, 0)); } else if (strcmp(argv[1], "eventfd") == 0) { for (i = 0; i < n; ++i) close(eventfd(0, 0)); } else if (strcmp(argv[1], "pipe") == 0) { for (i = 0; i < n; ++i) if (pipe(pipefd) == 0) { close(pipefd[0]); close(pipefd[1]); } } printf("%s %s test finished\r\n",argv[1],argv[2]); return 0; } Change-Id: Ia1b8a1fb9142edb00c040e44ec644d007f81f5d2 PiperOrigin-RevId: 251531096 --- pkg/sentry/fs/timerfd/timerfd.go | 2 ++ pkg/sentry/kernel/epoll/epoll.go | 2 ++ pkg/sentry/kernel/eventfd/eventfd.go | 2 ++ 3 files changed, 6 insertions(+) (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index bce5f091d..c1721f434 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -54,6 +54,8 @@ type TimerOperations struct { // NewFile returns a timerfd File that receives time from c. func NewFile(ctx context.Context, c ktime.Clock) *fs.File { dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[timerfd]") + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() tops := &TimerOperations{} tops.timer = ktime.NewTimer(c, tops) // Timerfds reject writes, but the Write flag must be set in order to diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index bbacba1f4..43ae22a5d 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -156,6 +156,8 @@ var cycleMu sync.Mutex func NewEventPoll(ctx context.Context) *fs.File { // name matches fs/eventpoll.c:epoll_create1. dirent := fs.NewDirent(anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]")) + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{ files: make(map[FileIdentifier]*pollEntry), }) diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 2f900be38..fe474cbf0 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -69,6 +69,8 @@ type EventOperations struct { func New(ctx context.Context, initVal uint64, semMode bool) *fs.File { // name matches fs/eventfd.c:eventfd_file_create. dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[eventfd]") + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{ val: initVal, semMode: semMode, -- cgit v1.2.3 From d3ed9baac0dc967eaf6d3e3f986cafe60604121a Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Wed, 5 Jun 2019 13:59:01 -0700 Subject: Implement dumpability tracking and checks We don't actually support core dumps, but some applications want to get/set dumpability, which still has an effect in procfs. Lack of support for set-uid binaries or fs creds simplifies things a bit. As-is, processes started via CreateProcess (i.e., init and sentryctl exec) have normal dumpability. I'm a bit torn on whether sentryctl exec tasks should be dumpable, but at least since they have no parent normal UID/GID checks should protect them. PiperOrigin-RevId: 251712714 --- pkg/abi/linux/prctl.go | 7 +++++ pkg/sentry/fs/proc/inode.go | 40 ++++++++++++++++++++++-- pkg/sentry/fs/proc/task.go | 17 +++++++++- pkg/sentry/kernel/ptrace.go | 17 +++++++++- pkg/sentry/kernel/task_exec.go | 7 +++++ pkg/sentry/kernel/task_identity.go | 24 ++++++++++++-- pkg/sentry/mm/lifecycle.go | 6 ++-- pkg/sentry/mm/metadata.go | 30 ++++++++++++++++++ pkg/sentry/mm/mm.go | 6 ++++ pkg/sentry/syscalls/linux/sys_prctl.go | 33 ++++++++++++++++++-- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/prctl.cc | 34 ++++++++++++++++++++ test/syscalls/linux/proc.cc | 57 ++++++++++++++++++++++++++++++++++ 13 files changed, 268 insertions(+), 11 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go index 0428282dd..391cfaa1c 100644 --- a/pkg/abi/linux/prctl.go +++ b/pkg/abi/linux/prctl.go @@ -155,3 +155,10 @@ const ( ARCH_GET_GS = 0x1004 ARCH_SET_CPUID = 0x1012 ) + +// Flags for prctl(PR_SET_DUMPABLE), defined in include/linux/sched/coredump.h. +const ( + SUID_DUMP_DISABLE = 0 + SUID_DUMP_USER = 1 + SUID_DUMP_ROOT = 2 +) diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go index 379569823..986bc0a45 100644 --- a/pkg/sentry/fs/proc/inode.go +++ b/pkg/sentry/fs/proc/inode.go @@ -21,11 +21,14 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" ) // taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr -// method to return the task as the owner. +// method to return either the task or root as the owner, depending on the +// task's dumpability. // // +stateify savable type taskOwnedInodeOps struct { @@ -41,9 +44,42 @@ func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) ( if err != nil { return fs.UnstableAttr{}, err } - // Set the task owner as the file owner. + + // By default, set the task owner as the file owner. creds := i.t.Credentials() uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID} + + // Linux doesn't apply dumpability adjustments to world + // readable/executable directories so that applications can stat + // /proc/PID to determine the effective UID of a process. See + // fs/proc/base.c:task_dump_owner. + if fs.IsDir(inode.StableAttr) && uattr.Perms == fs.FilePermsFromMode(0555) { + return uattr, nil + } + + // If the task is not dumpable, then root (in the namespace preferred) + // owns the file. + var m *mm.MemoryManager + i.t.WithMuLocked(func(t *kernel.Task) { + m = t.MemoryManager() + }) + + if m == nil { + uattr.Owner.UID = auth.RootKUID + uattr.Owner.GID = auth.RootKGID + } else if m.Dumpability() != mm.UserDumpable { + if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() { + uattr.Owner.UID = kuid + } else { + uattr.Owner.UID = auth.RootKUID + } + if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() { + uattr.Owner.GID = kgid + } else { + uattr.Owner.GID = auth.RootKGID + } + } + return uattr, nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 77e03d349..21a965f90 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -96,7 +96,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks boo contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) } - // TODO(b/31916171): Set EUID/EGID based on dumpability. + // N.B. taskOwnedInodeOps enforces dumpability-based ownership. d := &taskDir{ Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, @@ -667,6 +667,21 @@ func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(c, msrc, fs.SpecialFile, t) } +// Check implements fs.InodeOperations.Check. +func (c *comm) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool { + // This file can always be read or written by members of the same + // thread group. See fs/proc/base.c:proc_tid_comm_permission. + // + // N.B. This check is currently a no-op as we don't yet support writing + // and this file is world-readable anyways. + t := kernel.TaskFromContext(ctx) + if t != nil && t.ThreadGroup() == c.t.ThreadGroup() && !p.Execute { + return true + } + + return fs.ContextCanAccessFile(ctx, inode, p) +} + // GetFile implements fs.InodeOperations.GetFile. func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 4423e7efd..193447b17 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -19,6 +19,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -92,6 +93,14 @@ const ( // ptrace(2), subsection "Ptrace access mode checking". If attach is true, it // checks for access mode PTRACE_MODE_ATTACH; otherwise, it checks for access // mode PTRACE_MODE_READ. +// +// NOTE(b/30815691): The result of CanTrace is immediately stale (e.g., a +// racing setuid(2) may change traceability). This may pose a risk when a task +// changes from traceable to not traceable. This is only problematic across +// execve, where privileges may increase. +// +// We currently do not implement privileged executables (set-user/group-ID bits +// and file capabilities), so that case is not reachable. func (t *Task) CanTrace(target *Task, attach bool) bool { // "1. If the calling thread and the target thread are in the same thread // group, access is always allowed." - ptrace(2) @@ -162,7 +171,13 @@ func (t *Task) CanTrace(target *Task, attach bool) bool { if cgid := callerCreds.RealKGID; cgid != targetCreds.RealKGID || cgid != targetCreds.EffectiveKGID || cgid != targetCreds.SavedKGID { return false } - // TODO(b/31916171): dumpability check + var targetMM *mm.MemoryManager + target.WithMuLocked(func(t *Task) { + targetMM = t.MemoryManager() + }) + if targetMM != nil && targetMM.Dumpability() != mm.UserDumpable { + return false + } if callerCreds.UserNamespace != targetCreds.UserNamespace { return false } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 5d1425d5c..35d5cb90c 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -68,6 +68,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -198,6 +199,12 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { return flags.CloseOnExec }) + // NOTE(b/30815691): We currently do not implement privileged + // executables (set-user/group-ID bits and file capabilities). This + // allows us to unconditionally enable user dumpability on the new mm. + // See fs/exec.c:setup_new_exec. + r.tc.MemoryManager.SetDumpability(mm.UserDumpable) + // Switch to the new process. t.MemoryManager().Deactivate() t.mu.Lock() diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 17f08729a..ec95f78d0 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -206,8 +207,17 @@ func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) { // (filesystem UIDs aren't implemented, nor are any of the capabilities in // question) - // Not documented, but compare Linux's kernel/cred.c:commit_creds(). if oldE != newE { + // "[dumpability] is reset to the current value contained in + // the file /proc/sys/fs/suid_dumpable (which by default has + // the value 0), in the following circumstances: The process's + // effective user or group ID is changed." - prctl(2) + // + // (suid_dumpable isn't implemented, so we just use the + // default. + t.MemoryManager().SetDumpability(mm.NotDumpable) + + // Not documented, but compare Linux's kernel/cred.c:commit_creds(). t.parentDeathSignal = 0 } } @@ -303,8 +313,18 @@ func (t *Task) setKGIDsUncheckedLocked(newR, newE, newS auth.KGID) { t.creds = t.creds.Fork() // See doc for creds. t.creds.RealKGID, t.creds.EffectiveKGID, t.creds.SavedKGID = newR, newE, newS - // Not documented, but compare Linux's kernel/cred.c:commit_creds(). if oldE != newE { + // "[dumpability] is reset to the current value contained in + // the file /proc/sys/fs/suid_dumpable (which by default has + // the value 0), in the following circumstances: The process's + // effective user or group ID is changed." - prctl(2) + // + // (suid_dumpable isn't implemented, so we just use the + // default. + t.MemoryManager().SetDumpability(mm.NotDumpable) + + // Not documented, but compare Linux's + // kernel/cred.c:commit_creds(). t.parentDeathSignal = 0 } } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 7a65a62a2..7646d5ab2 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -37,6 +37,7 @@ func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *Memo privateRefs: &privateRefs{}, users: 1, auxv: arch.Auxv{}, + dumpability: UserDumpable, aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, } } @@ -79,8 +80,9 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { envv: mm.envv, auxv: append(arch.Auxv(nil), mm.auxv...), // IncRef'd below, once we know that there isn't an error. - executable: mm.executable, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + executable: mm.executable, + dumpability: mm.dumpability, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, } // Copy vmas. diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 9768e51f1..c218006ee 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -20,6 +20,36 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" ) +// Dumpability describes if and how core dumps should be created. +type Dumpability int + +const ( + // NotDumpable indicates that core dumps should never be created. + NotDumpable Dumpability = iota + + // UserDumpable indicates that core dumps should be created, owned by + // the current user. + UserDumpable + + // RootDumpable indicates that core dumps should be created, owned by + // root. + RootDumpable +) + +// Dumpability returns the dumpability. +func (mm *MemoryManager) Dumpability() Dumpability { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + return mm.dumpability +} + +// SetDumpability sets the dumpability. +func (mm *MemoryManager) SetDumpability(d Dumpability) { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + mm.dumpability = d +} + // ArgvStart returns the start of the application argument vector. // // There is no guarantee that this value is sensible w.r.t. ArgvEnd. diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index eb6defa2b..0a026ff8c 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -219,6 +219,12 @@ type MemoryManager struct { // executable is protected by metadataMu. executable *fs.Dirent + // dumpability describes if and how this MemoryManager may be dumped to + // userspace. + // + // dumpability is protected by metadataMu. + dumpability Dumpability + // aioManager keeps track of AIOContexts used for async IOs. AIOManager // must be cloned when CLONE_VM is used. aioManager aioManager diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 117ae1a0e..1b7e5616b 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -15,6 +15,7 @@ package linux import ( + "fmt" "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" @@ -23,6 +24,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" ) // Prctl implements linux syscall prctl(2). @@ -44,6 +46,33 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal())) return 0, nil, err + case linux.PR_GET_DUMPABLE: + d := t.MemoryManager().Dumpability() + switch d { + case mm.NotDumpable: + return linux.SUID_DUMP_DISABLE, nil, nil + case mm.UserDumpable: + return linux.SUID_DUMP_USER, nil, nil + case mm.RootDumpable: + return linux.SUID_DUMP_ROOT, nil, nil + default: + panic(fmt.Sprintf("Unknown dumpability %v", d)) + } + + case linux.PR_SET_DUMPABLE: + var d mm.Dumpability + switch args[1].Int() { + case linux.SUID_DUMP_DISABLE: + d = mm.NotDumpable + case linux.SUID_DUMP_USER: + d = mm.UserDumpable + default: + // N.B. Userspace may not pass SUID_DUMP_ROOT. + return 0, nil, syscall.EINVAL + } + t.MemoryManager().SetDumpability(d) + return 0, nil, nil + case linux.PR_GET_KEEPCAPS: if t.Credentials().KeepCaps { return 1, nil, nil @@ -171,9 +200,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } return 0, nil, t.DropBoundingCapability(cp) - case linux.PR_GET_DUMPABLE, - linux.PR_SET_DUMPABLE, - linux.PR_GET_TIMING, + case linux.PR_GET_TIMING, linux.PR_SET_TIMING, linux.PR_GET_TSC, linux.PR_SET_TSC, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ba9fd6d1f..7633ab162 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1317,6 +1317,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "//test/util:cleanup", "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_util", diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index bce42dc74..bd1779557 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -17,10 +17,12 @@ #include #include #include + #include #include "gtest/gtest.h" #include "test/util/capability_util.h" +#include "test/util/cleanup.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" @@ -35,6 +37,16 @@ namespace testing { namespace { +#ifndef SUID_DUMP_DISABLE +#define SUID_DUMP_DISABLE 0 +#endif /* SUID_DUMP_DISABLE */ +#ifndef SUID_DUMP_USER +#define SUID_DUMP_USER 1 +#endif /* SUID_DUMP_USER */ +#ifndef SUID_DUMP_ROOT +#define SUID_DUMP_ROOT 2 +#endif /* SUID_DUMP_ROOT */ + TEST(PrctlTest, NameInitialized) { const size_t name_length = 20; char name[name_length] = {}; @@ -178,6 +190,28 @@ TEST(PrctlTest, InvalidPrSetMM) { ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM)); } +// Sanity check that dumpability is remembered. +TEST(PrctlTest, SetGetDumpability) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); + EXPECT_THAT(prctl(PR_GET_DUMPABLE), + SyscallSucceedsWithValue(SUID_DUMP_DISABLE)); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); + EXPECT_THAT(prctl(PR_GET_DUMPABLE), SyscallSucceedsWithValue(SUID_DUMP_USER)); +} + +// SUID_DUMP_ROOT cannot be set via PR_SET_DUMPABLE. +TEST(PrctlTest, RootDumpability) { + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_ROOT), + SyscallFailsWithErrno(EINVAL)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index ede6fb860..924b98e3a 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -69,9 +69,11 @@ // way to get it tested on both gVisor, PTrace and Linux. using ::testing::AllOf; +using ::testing::AnyOf; using ::testing::ContainerEq; using ::testing::Contains; using ::testing::ContainsRegex; +using ::testing::Eq; using ::testing::Gt; using ::testing::HasSubstr; using ::testing::IsSupersetOf; @@ -86,6 +88,16 @@ namespace gvisor { namespace testing { namespace { +#ifndef SUID_DUMP_DISABLE +#define SUID_DUMP_DISABLE 0 +#endif /* SUID_DUMP_DISABLE */ +#ifndef SUID_DUMP_USER +#define SUID_DUMP_USER 1 +#endif /* SUID_DUMP_USER */ +#ifndef SUID_DUMP_ROOT +#define SUID_DUMP_ROOT 2 +#endif /* SUID_DUMP_ROOT */ + // O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 // because "it isn't needed", even though Linux can return it via F_GETFL. constexpr int kOLargeFile = 00100000; @@ -1896,6 +1908,51 @@ void CheckDuplicatesRecursively(std::string path) { TEST(Proc, NoDuplicates) { CheckDuplicatesRecursively("/proc"); } +// Most /proc/PID files are owned by the task user with SUID_DUMP_USER. +TEST(ProcPid, UserDumpableOwner) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); + + // This applies to the task directory itself and files inside. + struct stat st; + ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); + + ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); +} + +// /proc/PID files are owned by root with SUID_DUMP_DISABLE. +TEST(ProcPid, RootDumpableOwner) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); + + // This *does not* applies to the task directory itself (or other 0555 + // directories), but does to files inside. + struct stat st; + ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); + + // This file is owned by root. Also allow nobody in case this test is running + // in a userns without root mapped. + ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); + EXPECT_THAT(st.st_uid, AnyOf(Eq(0), Eq(65534))); + EXPECT_THAT(st.st_gid, AnyOf(Eq(0), Eq(65534))); +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 57772db2e7351511de422baeecf807785709ee5d Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Wed, 5 Jun 2019 18:39:30 -0700 Subject: Shutdown host sockets on internal shutdown This is required to make the shutdown visible to peers outside the sandbox. The readClosed / writeClosed fields were dropped, as they were preventing a shutdown socket from reading the remainder of queued bytes. The host syscalls will return the appropriate errors for shutdown. The control message tests have been split out of socket_unix.cc to make the (few) remaining tests accessible to testing inherited host UDS, which don't support sending control messages. Updates #273 PiperOrigin-RevId: 251763060 --- pkg/sentry/fs/host/socket.go | 62 +- pkg/sentry/fs/host/socket_test.go | 156 --- runsc/boot/filter/config.go | 4 + test/syscalls/linux/BUILD | 23 + test/syscalls/linux/socket_abstract.cc | 5 + test/syscalls/linux/socket_filesystem.cc | 5 + test/syscalls/linux/socket_unix.cc | 1518 ++---------------------------- test/syscalls/linux/socket_unix_cmsg.cc | 1473 +++++++++++++++++++++++++++++ test/syscalls/linux/socket_unix_cmsg.h | 30 + test/syscalls/linux/socket_unix_pair.cc | 5 + 10 files changed, 1655 insertions(+), 1626 deletions(-) create mode 100644 test/syscalls/linux/socket_unix_cmsg.cc create mode 100644 test/syscalls/linux/socket_unix_cmsg.h (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 3ed137006..e4ec0f62c 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -15,6 +15,7 @@ package host import ( + "fmt" "sync" "syscall" @@ -51,20 +52,6 @@ type ConnectedEndpoint struct { // ref keeps track of references to a connectedEndpoint. ref refs.AtomicRefCount - // mu protects fd, readClosed and writeClosed. - mu sync.RWMutex `state:"nosave"` - - // file is an *fd.FD containing the FD backing this endpoint. It must be - // set to nil if it has been closed. - file *fd.FD `state:"nosave"` - - // readClosed is true if the FD has read shutdown or if it has been closed. - readClosed bool - - // writeClosed is true if the FD has write shutdown or if it has been - // closed. - writeClosed bool - // If srfd >= 0, it is the host FD that file was imported from. srfd int `state:"wait"` @@ -78,6 +65,13 @@ type ConnectedEndpoint struct { // prevent lots of small messages from filling the real send buffer // size on the host. sndbuf int `state:"nosave"` + + // mu protects the fields below. + mu sync.RWMutex `state:"nosave"` + + // file is an *fd.FD containing the FD backing this endpoint. It must be + // set to nil if it has been closed. + file *fd.FD `state:"nosave"` } // init performs initialization required for creating new ConnectedEndpoints and @@ -208,9 +202,6 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return 0, false, syserr.ErrClosedForSend - } if !controlMessages.Empty() { return 0, false, syserr.ErrInvalidEndpointState @@ -244,8 +235,13 @@ func (c *ConnectedEndpoint) SendNotify() {} // CloseSend implements transport.ConnectedEndpoint.CloseSend. func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() - c.writeClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err)) + } } // CloseNotify implements transport.ConnectedEndpoint.CloseNotify. @@ -255,9 +251,7 @@ func (c *ConnectedEndpoint) CloseNotify() {} func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0 } @@ -285,9 +279,6 @@ func (c *ConnectedEndpoint) EventUpdate() { func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive - } var cm unet.ControlMessage if numRights > 0 { @@ -344,31 +335,34 @@ func (c *ConnectedEndpoint) RecvNotify() {} // CloseRecv implements transport.Receiver.CloseRecv. func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() - c.readClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err)) + } } // Readable implements transport.Receiver.Readable. func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0 } // SendQueuedSize implements transport.Receiver.SendQueuedSize. func (c *ConnectedEndpoint) SendQueuedSize() int64 { - // SendQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } // RecvQueuedSize implements transport.Receiver.RecvQueuedSize. func (c *ConnectedEndpoint) RecvQueuedSize() int64 { - // RecvQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 06392a65a..bc3ce5627 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -198,20 +198,6 @@ func TestListen(t *testing.T) { } } -func TestSend(t *testing.T) { - e := ConnectedEndpoint{writeClosed: true} - if _, _, err := e.Send(nil, transport.ControlMessages{}, tcpip.FullAddress{}); err != syserr.ErrClosedForSend { - t.Errorf("Got %#v.Send() = %v, want = %v", e, err, syserr.ErrClosedForSend) - } -} - -func TestRecv(t *testing.T) { - e := ConnectedEndpoint{readClosed: true} - if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive { - t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive) - } -} - func TestPasscred(t *testing.T) { e := ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { @@ -244,20 +230,6 @@ func TestQueuedSize(t *testing.T) { } } -func TestReadable(t *testing.T) { - e := ConnectedEndpoint{readClosed: true} - if got, want := e.Readable(), true; got != want { - t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want) - } -} - -func TestWritable(t *testing.T) { - e := ConnectedEndpoint{writeClosed: true} - if got, want := e.Writable(), true; got != want { - t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want) - } -} - func TestRelease(t *testing.T) { f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { @@ -272,131 +244,3 @@ func TestRelease(t *testing.T) { t.Errorf("got = %#v, want = %#v", c, want) } } - -func TestClose(t *testing.T) { - type testCase struct { - name string - cep *ConnectedEndpoint - addFD bool - f func() - want *ConnectedEndpoint - } - - var tests []testCase - - // nil is the value used by ConnectedEndpoint to indicate a closed file. - // Non-nil files are used to check if the file gets closed. - - f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - tests = append(tests, testCase{ - name: "First CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} - tests = append(tests, testCase{ - name: "Second CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - tests = append(tests, testCase{ - name: "First CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} - tests = append(tests, testCase{ - name: "Second CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} - tests = append(tests, testCase{ - name: "CloseSend then CloseRecv", - cep: c, - addFD: true, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} - tests = append(tests, testCase{ - name: "CloseRecv then CloseSend", - cep: c, - addFD: true, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} - tests = append(tests, testCase{ - name: "Full close then CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} - tests = append(tests, testCase{ - name: "Full close then CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - for _, test := range tests { - if test.addFD { - fdnotifier.AddFD(int32(test.cep.file.FD()), nil) - } - if test.f(); !reflect.DeepEqual(test.cep, test.want) { - t.Errorf("%s: got = %#v, want = %#v", test.name, test.cep, test.want) - } - } -} diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 652da1cef..ef2dbfad2 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -246,6 +246,10 @@ var allowedSyscalls = seccomp.SyscallRules{ }, syscall.SYS_SETITIMER: {}, syscall.SYS_SHUTDOWN: []seccomp.Rule{ + // Used by fs/host to shutdown host sockets. + {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)}, + {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)}, + // Used by unet to shutdown connections. {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 7633ab162..0cb7b47b6 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2096,6 +2096,7 @@ cc_binary( deps = [ ":socket_generic_test_cases", ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", @@ -2369,6 +2370,7 @@ cc_binary( deps = [ ":socket_generic_test_cases", ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", @@ -2490,6 +2492,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_unix_cmsg_test_cases", + testonly = 1, + srcs = [ + "socket_unix_cmsg.cc", + ], + hdrs = [ + "socket_unix_cmsg.h", + ], + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], + alwayslink = 1, +) + cc_library( name = "socket_stream_blocking_test_cases", testonly = 1, @@ -2733,6 +2755,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc index 503ba986b..715d87b76 100644 --- a/test/syscalls/linux/socket_abstract.cc +++ b/test/syscalls/linux/socket_abstract.cc @@ -17,6 +17,7 @@ #include "test/syscalls/linux/socket_generic.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -38,5 +39,9 @@ INSTANTIATE_TEST_SUITE_P( AbstractUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + AbstractUnixSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc index e38a320f6..74e262959 100644 --- a/test/syscalls/linux/socket_filesystem.cc +++ b/test/syscalls/linux/socket_filesystem.cc @@ -17,6 +17,7 @@ #include "test/syscalls/linux/socket_generic.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -38,5 +39,9 @@ INSTANTIATE_TEST_SUITE_P( FilesystemUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + FilesystemUnixSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 95cf8d2a3..875f0391f 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -32,1437 +32,16 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -// This file is a generic socket test file. It must be built with another file -// that provides the test types. - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(UnixSocketPairTest, BasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, BasicTwoFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); -} - -TEST_P(UnixSocketPairTest, BasicThreeFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd())); -} - -TEST_P(UnixSocketPairTest, BadFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sent_fd = -1; - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(sent_fd))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EBADF)); -} - -// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. -// The difference is that when calling recvmsg, no space for FDs is provided, -// only space for the cmsg header. -TEST_P(UnixSocketPairTest, BasicFDPassNoSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to -// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for -// msg_controllen and msg_control. msg_controllen is set to the correct size to -// accomidate the FD, but msg_control is set to NULL. In this case, msg_control -// should override msg_controllen. -TEST_P(UnixSocketPairTest, BasicFDPassNullControlMsgCtrunc) { - // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - msg.msg_controllen = CMSG_SPACE(1); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough -// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the -// msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassNotEnoughSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0) + 1); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough -// space to receive two of them. It then verifies that the MSG_CTRUNC flag is -// set in the msghdr. -TEST_P(UnixSocketPairTest, BasicThreeFDPassTruncationMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(2 * sizeof(int))); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicFDPassUnalignedRecv starts off by sending a single FD just like -// BasicFDPass. The difference is that when calling recvmsg, the length of the -// receive data is only aligned on a 4 byte boundry instead of the normal 8. -TEST_P(UnixSocketPairTest, BasicFDPassUnalignedRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned( - sockets->second_fd(), &fd, received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and only provides enough -// space to receive just it. (Normally the minimum amount of space one would -// provide would be enough space for two FDs.) It then verifies that the -// MSG_CTRUNC flag is not set in the msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassUnalignedRecvNoMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, 0); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only -// provides enough space to receive one of them. It then verifies that the -// MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair->first_fd(), pair->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - // CMSG_SPACE rounds up to two FDs, we only want one. - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -TEST_P(UnixSocketPairTest, ConcurrentBasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sockfd1 = sockets->first_fd(); - auto recv_func = [sockfd1, sent_data]() { - char received_data[20]; - int fd = -1; - RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data)); - ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - char buf[20]; - ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - }; - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - ScopedThread t(recv_func); - - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - t.Join(); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// FDPassNoRecv checks that the control message can be safely ignored by using -// read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairTest, FDPassNoRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Read while ignoring the passed FD. - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // Check that the socket still works for reads and writes. - ASSERT_NO_FATAL_FAILURE( - TransferTest(sockets->first_fd(), sockets->second_fd())); -} - -// FDPassInterspersed1 checks that sent control messages cannot be read before -// their associated data has been read. -TEST_P(UnixSocketPairTest, FDPassInterspersed1) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Check that we don't get a control message, but do get the data. - char received_data[20]; - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -// FDPassInterspersed2 checks that sent control messages cannot be read after -// their assocated data has been read while ignoring the control message by -// using read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairTest, FDPassInterspersed2) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -TEST_P(UnixSocketPairTest, FDPassNotCoalesced) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), - sent_data1, sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), - sent_data2, sizeof(sent_data2))); - - char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd1 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1, - sizeof(received_data1), sizeof(sent_data1)); - - EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); - TransferTest(pair1->first_fd(), pair1->second_fd()); - - char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd2 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2, - sizeof(received_data2), sizeof(sent_data2)); - - EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); - TransferTest(pair2->first_fd(), pair2->second_fd()); -} - -TEST_P(UnixSocketPairTest, FDPassPeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char peek_data[20]; - int peek_fd = -1; - PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data)); - EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data))); - TransferTest(peek_fd, pair->first_fd()); - EXPECT_THAT(close(peek_fd), SyscallSucceeds()); - - char received_data[20]; - int received_fd = -1; - RecvSingleFD(sockets->second_fd(), &received_fd, received_data, - sizeof(received_data)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - TransferTest(received_fd, pair->first_fd()); - EXPECT_THAT(close(received_fd), SyscallSucceeds()); -} - -TEST_P(UnixSocketPairTest, BasicCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, SendNullCredsAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, WriteAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->first_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, CredPassTruncated) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); - - pid_t pid = 0; - memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid)); - EXPECT_EQ(pid, sent_creds.pid); -} - -// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that -// receiving the full set does not result in MSG_CTRUNC being set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassNoMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should not be truncated. - EXPECT_EQ(msg.msg_flags, 0); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives -// the data without providing space for any credentials and verifies that -// MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives -// the data while providing enough space for only the first field of the -// credentials and verifies that MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassTruncatedMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -TEST_P(UnixSocketPairTest, SoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int opt; - socklen_t optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - SetSoPassCred(sockets->first_fd()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_TRUE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - int zero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero, - sizeof(zero)), - SyscallSucceeds()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); -} - -TEST_P(UnixSocketPairTest, NoDataCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct msghdr msg = {}; - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_CREDENTIALS; - cmsg->cmsg_len = CMSG_LEN(0); - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UnixSocketPairTest, NoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - char received_data[20]; - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, CredAndFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds, - pair->second_fd(), sent_data, - sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, FDPassBeforeSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, CloexecDroppedWhenFDPassed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = ASSERT_NO_ERRNO_AND_VALUE( - UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0)); -} - -TEST_P(UnixSocketPairTest, CloexecRecvFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - int fd = -1; - memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCredWithoutCredSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// This test will validate that MSG_CTRUNC as an input flag to recvmsg will -// not appear as an output flag on the control message when truncation doesn't -// happen. -TEST_P(UnixSocketPairTest, MsgCtruncInputIsNoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - // Now we should verify that MSG_CTRUNC wasn't set as an output flag. - EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0) / 2]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; +// This file contains tests specific to Unix domain sockets. It does not contain +// tests for UDS control messages. Those belong in socket_unix_cmsg.cc. +// +// This file is a generic socket test file. It must be built with another file +// that provides the test types. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); +namespace gvisor { +namespace testing { - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(msg.msg_controllen, 0); -} +namespace { TEST_P(UnixSocketPairTest, InvalidGetSockOpt) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -1519,6 +98,14 @@ TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { TEST_P(UnixSocketPairTest, TIOCINQSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCINQ. + // Skip the test. + int size = -1; + int ret = ioctl(sockets->first_fd(), TIOCINQ, &size); + SKIP_IF(ret == -1 && errno == ENOTTY); + } + int size = -1; EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds()); EXPECT_EQ(size, 0); @@ -1544,6 +131,14 @@ TEST_P(UnixSocketPairTest, TIOCINQSucceeds) { TEST_P(UnixSocketPairTest, TIOCOUTQSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCOUTQ. + // Skip the test. + int size = -1; + int ret = ioctl(sockets->second_fd(), TIOCOUTQ, &size); + SKIP_IF(ret == -1 && errno == ENOTTY); + } + int size = -1; EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds()); EXPECT_EQ(size, 0); @@ -1580,19 +175,70 @@ TEST_P(UnixSocketPairTest, NetdeviceIoctlsSucceed) { } } -TEST_P(UnixSocketPairTest, SocketShutdown) { +TEST_P(UnixSocketPairTest, Shutdown) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[20]; + const std::string data = "abc"; - ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3), - SyscallSucceedsWithValue(3)); + ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); // Shutting down a socket does not clear the buffer. - ASSERT_THAT(ReadFd(sockets->second_fd(), buf, 3), - SyscallSucceedsWithValue(3)); - EXPECT_EQ(data, absl::string_view(buf, 3)); + char buf[3]; + ASSERT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); +} + +TEST_P(UnixSocketPairTest, ShutdownRead) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RD), SyscallSucceeds()); + + // When the socket is shutdown for read, read behavior varies between + // different socket types. This is covered by the various ReadOneSideClosed + // test cases. + + // ... and the peer cannot write. + const std::string data = "abc"; + EXPECT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), + SyscallFailsWithErrno(EPIPE)); + + // ... but the socket can still write. + ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // ... and the peer can still read. + char buf[3]; + EXPECT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); +} + +TEST_P(UnixSocketPairTest, ShutdownWrite) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds()); + + // When the socket is shutdown for write, it cannot write. + const std::string data = "abc"; + EXPECT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallFailsWithErrno(EPIPE)); + + // ... and the peer read behavior varies between different socket types. This + // is covered by the various ReadOneSideClosed test cases. + + // ... but the peer can still write. + char buf[3]; + ASSERT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // ... and the socket can still read. + EXPECT_THAT(ReadFd(sockets->first_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); } TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc new file mode 100644 index 000000000..b0ab26847 --- /dev/null +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -0,0 +1,1473 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_unix_cmsg.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +// This file contains tests for control message in Unix domain sockets. +// +// This file is a generic socket test file. It must be built with another file +// that provides the test types. + +namespace gvisor { +namespace testing { + +namespace { + +TEST_P(UnixSocketPairCmsgTest, BasicFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); + + char received_data[20]; + int received_fds[] = {-1, -1}; + + ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair3 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); + + char received_data[20]; + int received_fds[] = {-1, -1, -1}; + + ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BadFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sent_fd = -1; + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(sent_fd))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd)); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EBADF)); +} + +// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. +// The difference is that when calling recvmsg, no space for FDs is provided, +// only space for the cmsg header. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0)); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to +// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0)); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for +// msg_controllen and msg_control. msg_controllen is set to the correct size to +// accomidate the FD, but msg_control is set to NULL. In this case, msg_control +// should override msg_controllen. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNullControlMsgCtrunc) { + // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control. + SKIP_IF(IsRunningOnGvisor()); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + msg.msg_controllen = CMSG_SPACE(1); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough +// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the +// msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNotEnoughSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0) + 1); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough +// space to receive two of them. It then verifies that the MSG_CTRUNC flag is +// set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair3 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(2 * sizeof(int))); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +// BasicFDPassUnalignedRecv starts off by sending a single FD just like +// BasicFDPass. The difference is that when calling recvmsg, the length of the +// receive data is only aligned on a 4 byte boundry instead of the normal 8. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned( + sockets->second_fd(), &fd, received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and only provides enough +// space to receive just it. (Normally the minimum amount of space one would +// provide would be enough space for two FDs.) It then verifies that the +// MSG_CTRUNC flag is not set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecvNoMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, 0); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only +// provides enough space to receive one of them. It then verifies that the +// MSG_CTRUNC flag is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair->first_fd(), pair->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + // CMSG_SPACE rounds up to two FDs, we only want one. + char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +TEST_P(UnixSocketPairCmsgTest, ConcurrentBasicFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sockfd1 = sockets->first_fd(); + auto recv_func = [sockfd1, sent_data]() { + char received_data[20]; + int fd = -1; + RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data)); + ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + char buf[20]; + ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + }; + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + ScopedThread t(recv_func); + + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + t.Join(); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +// FDPassNoRecv checks that the control message can be safely ignored by using +// read(2) instead of recvmsg(2). +TEST_P(UnixSocketPairCmsgTest, FDPassNoRecv) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + // Read while ignoring the passed FD. + char received_data[20]; + ASSERT_THAT( + ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // Check that the socket still works for reads and writes. + ASSERT_NO_FATAL_FAILURE( + TransferTest(sockets->first_fd(), sockets->second_fd())); +} + +// FDPassInterspersed1 checks that sent control messages cannot be read before +// their associated data has been read. +TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed1) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char written_data[20]; + RandomizeBuffer(written_data, sizeof(written_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), + SyscallSucceedsWithValue(sizeof(written_data))); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + // Check that we don't get a control message, but do get the data. + char received_data[20]; + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); + EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); +} + +// FDPassInterspersed2 checks that sent control messages cannot be read after +// their assocated data has been read while ignoring the control message by +// using read(2) instead of recvmsg(2). +TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed2) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char written_data[20]; + RandomizeBuffer(written_data, sizeof(written_data)); + ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), + SyscallSucceedsWithValue(sizeof(written_data))); + + char received_data[20]; + ASSERT_THAT( + ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassNotCoalesced) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data1[20]; + RandomizeBuffer(sent_data1, sizeof(sent_data1)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), + sent_data1, sizeof(sent_data1))); + + char sent_data2[20]; + RandomizeBuffer(sent_data2, sizeof(sent_data2)); + + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), + sent_data2, sizeof(sent_data2))); + + char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; + int received_fd1 = -1; + + RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1, + sizeof(received_data1), sizeof(sent_data1)); + + EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); + TransferTest(pair1->first_fd(), pair1->second_fd()); + + char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; + int received_fd2 = -1; + + RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2, + sizeof(received_data2), sizeof(sent_data2)); + + EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); + TransferTest(pair2->first_fd(), pair2->second_fd()); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassPeek) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char peek_data[20]; + int peek_fd = -1; + PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data)); + EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data))); + TransferTest(peek_fd, pair->first_fd()); + EXPECT_THAT(close(peek_fd), SyscallSucceeds()); + + char received_data[20]; + int received_fd = -1; + RecvSingleFD(sockets->second_fd(), &received_fd, received_data, + sizeof(received_data)); + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + TransferTest(received_fd, pair->first_fd()); + EXPECT_THAT(close(received_fd), SyscallSucceeds()); +} + +TEST_P(UnixSocketPairCmsgTest, BasicCredPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + EXPECT_EQ(sent_creds.pid, received_creds.pid); + EXPECT_EQ(sent_creds.uid, received_creds.uid); + EXPECT_EQ(sent_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->first_fd()); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, + SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + SetSoPassCred(sockets->second_fd()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->first_fd()); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + SetSoPassCred(sockets->first_fd()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, CredPassTruncated) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0) + sizeof(pid_t)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + + pid_t pid = 0; + memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid)); + EXPECT_EQ(pid, sent_creds.pid); +} + +// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that +// receiving the full set does not result in MSG_CTRUNC being set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassNoMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(struct ucred))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should not be truncated. + EXPECT_EQ(msg.msg_flags, 0); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives +// the data without providing space for any credentials and verifies that +// MSG_CTRUNC is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassNoSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should be truncated. + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives +// the data while providing enough space for only the first field of the +// credentials and verifies that MSG_CTRUNC is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassTruncatedMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0) + sizeof(pid_t)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should be truncated. + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +TEST_P(UnixSocketPairCmsgTest, SoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int opt; + socklen_t optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + SetSoPassCred(sockets->first_fd()); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_TRUE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + int zero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero, + sizeof(zero)), + SyscallSucceeds()); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); +} + +TEST_P(UnixSocketPairCmsgTest, NoDataCredPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct msghdr msg = {}; + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + char control[CMSG_SPACE(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_CREDENTIALS; + cmsg->cmsg_len = CMSG_LEN(0); + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UnixSocketPairCmsgTest, NoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + char received_data[20]; + + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, CredAndFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds, + pair->second_fd(), sent_data, + sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(sent_creds.pid, received_creds.pid); + EXPECT_EQ(sent_creds.uid, received_creds.uid); + EXPECT_EQ(sent_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassBeforeSoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, CloexecDroppedWhenFDPassed) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = ASSERT_NO_ERRNO_AND_VALUE( + UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, + sizeof(received_data))); + + EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0)); +} + +TEST_P(UnixSocketPairCmsgTest, CloexecRecvFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct iovec iov; + char received_data[20]; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC), + SyscallSucceedsWithValue(sizeof(received_data))); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); + + int fd = -1; + memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); + + EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_LEN(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[20]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// This test will validate that MSG_CTRUNC as an input flag to recvmsg will +// not appear as an output flag on the control message when truncation doesn't +// happen. +TEST_P(UnixSocketPairCmsgTest, MsgCtruncInputIsNoop) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct iovec iov; + char received_data[20]; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), + SyscallSucceedsWithValue(sizeof(received_data))); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); + + // Now we should verify that MSG_CTRUNC wasn't set as an output flag. + EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_LEN(0) / 2]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[20]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + EXPECT_EQ(msg.msg_controllen, 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.h b/test/syscalls/linux/socket_unix_cmsg.h new file mode 100644 index 000000000..431606903 --- /dev/null +++ b/test/syscalls/linux/socket_unix_cmsg.h @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of connected unix sockets about +// control messages. +using UnixSocketPairCmsgTest = SocketPairTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc index bacfc11e4..411fb4518 100644 --- a/test/syscalls/linux/socket_unix_pair.cc +++ b/test/syscalls/linux/socket_unix_pair.cc @@ -16,6 +16,7 @@ #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -33,5 +34,9 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + AllUnixDomainSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 2d2831e3541c8ae3c84f17cfd1bf0a26f2027044 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 15:03:44 -0700 Subject: Track and export socket state. This is necessary for implementing network diagnostic interfaces like /proc/net/{tcp,udp,unix} and sock_diag(7). For pass-through endpoints such as hostinet, we obtain the socket state from the backend. For netstack, we add explicit tracking of TCP states. PiperOrigin-RevId: 251934850 --- pkg/abi/linux/socket.go | 16 ++ pkg/sentry/fs/proc/net.go | 20 +-- pkg/sentry/socket/epsocket/epsocket.go | 44 +++++ pkg/sentry/socket/hostinet/socket.go | 24 +++ pkg/sentry/socket/netlink/socket.go | 5 + pkg/sentry/socket/rpcinet/socket.go | 6 + pkg/sentry/socket/socket.go | 4 + pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 9 ++ pkg/sentry/socket/unix/transport/connectionless.go | 16 ++ pkg/sentry/socket/unix/transport/unix.go | 4 + pkg/sentry/socket/unix/unix.go | 5 + pkg/tcpip/stack/transport_test.go | 4 + pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 6 + pkg/tcpip/transport/raw/endpoint.go | 5 + pkg/tcpip/transport/tcp/accept.go | 12 +- pkg/tcpip/transport/tcp/connect.go | 26 ++- pkg/tcpip/transport/tcp/endpoint.go | 174 ++++++++++++++------ pkg/tcpip/transport/tcp/endpoint_state.go | 42 ++--- pkg/tcpip/transport/tcp/rcv.go | 37 +++++ pkg/tcpip/transport/tcp/snd.go | 4 + pkg/tcpip/transport/tcp/tcp_test.go | 131 ++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 39 ++++- pkg/tcpip/transport/udp/endpoint.go | 6 + test/syscalls/linux/proc_net_unix.cc | 178 +++++++++++++++++++++ 27 files changed, 696 insertions(+), 127 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 417840731..44bd69df6 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -200,6 +200,22 @@ const ( SS_DISCONNECTING = 4 // In process of disconnecting. ) +// TCP protocol states, from include/net/tcp_states.h. +const ( + TCP_ESTABLISHED uint32 = iota + 1 + TCP_SYN_SENT + TCP_SYN_RECV + TCP_FIN_WAIT1 + TCP_FIN_WAIT2 + TCP_TIME_WAIT + TCP_CLOSE + TCP_CLOSE_WAIT + TCP_LAST_ACK + TCP_LISTEN + TCP_CLOSING + TCP_NEW_SYN_RECV +) + // SockAddrMax is the maximum size of a struct sockaddr, from // uapi/linux/socket.h. const SockAddrMax = 128 diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 4a107c739..3daaa962c 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -240,24 +240,6 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } } - var sockState int - switch sops.Endpoint().Type() { - case linux.SOCK_DGRAM: - sockState = linux.SS_CONNECTING - // Unlike Linux, we don't have unbound connection-less sockets, - // so no SS_DISCONNECTING. - - case linux.SOCK_SEQPACKET: - fallthrough - case linux.SOCK_STREAM: - // Connectioned. - if sops.Endpoint().(transport.ConnectingEndpoint).Connected() { - sockState = linux.SS_CONNECTED - } else { - sockState = linux.SS_UNCONNECTED - } - } - // In the socket entry below, the value for the 'Num' field requires // some consideration. Linux prints the address to the struct // unix_sock representing a socket in the kernel, but may redact the @@ -282,7 +264,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. - sockState, // State. + sops.State(), // State. sfile.InodeID(), // Inode. ) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index de4b963da..f91c5127a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -52,6 +52,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -2281,3 +2282,46 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { } return rv } + +// State implements socket.Socket.State. State translates the internal state +// returned by netstack to values defined by Linux. +func (s *SocketOperations) State() uint32 { + if s.family != linux.AF_INET && s.family != linux.AF_INET6 { + // States not implemented for this socket's family. + return 0 + } + + if !s.isPacketBased() { + // TCP socket. + switch tcp.EndpointState(s.Endpoint.State()) { + case tcp.StateEstablished: + return linux.TCP_ESTABLISHED + case tcp.StateSynSent: + return linux.TCP_SYN_SENT + case tcp.StateSynRecv: + return linux.TCP_SYN_RECV + case tcp.StateFinWait1: + return linux.TCP_FIN_WAIT1 + case tcp.StateFinWait2: + return linux.TCP_FIN_WAIT2 + case tcp.StateTimeWait: + return linux.TCP_TIME_WAIT + case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError: + return linux.TCP_CLOSE + case tcp.StateCloseWait: + return linux.TCP_CLOSE_WAIT + case tcp.StateLastAck: + return linux.TCP_LAST_ACK + case tcp.StateListen: + return linux.TCP_LISTEN + case tcp.StateClosing: + return linux.TCP_CLOSING + default: + // Internal or unknown state. + return 0 + } + } + + // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. + return 0 +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 41f9693bb..0d75580a3 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,9 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" @@ -519,6 +521,28 @@ func translateIOSyscallError(err error) error { return err } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + info := linux.TCPInfo{} + buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) + if err != nil { + if err != syscall.ENOPROTOOPT { + log.Warningf("Failed to get TCP socket info from %+v: %v", s, err) + } + // For non-TCP sockets, silently ignore the failure. + return 0 + } + if len(buf) != linux.SizeOfTCPInfo { + // Unmarshal below will panic if getsockopt returns a buffer of + // unexpected size. + log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo) + return 0 + } + + binary.Unmarshal(buf, usermem.ByteOrder, &info) + return uint32(info.State) +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index afd06ca33..16c79aa33 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -616,3 +616,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// State implements socket.Socket.State. +func (s *Socket) State() uint32 { + return s.ep.State() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 55e0b6665..bf42bdf69 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -830,6 +830,12 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + // TODO(b/127845868): Define a new rpc to query the socket state. + return 0 +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9393acd28..a99423365 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -116,6 +116,10 @@ type Socket interface { // SendTimeout gets the current timeout (in ns) for send operations. Zero // means no timeout, and negative means DONTWAIT. SendTimeout() int64 + + // State returns the current state of the socket, as represented by Linux in + // procfs. The returned state value is protocol-specific. + State() uint32 } // Provider is the interface implemented by providers of sockets for specific diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 5a2de0c4c..52f324eed 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -28,6 +28,7 @@ go_library( importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/ilist", "//pkg/refs", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 18e492862..9c8ec0365 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -17,6 +17,7 @@ package transport import ( "sync" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return ready } + +// State implements socket.Socket.State. +func (e *connectionedEndpoint) State() uint32 { + if e.Connected() { + return linux.SS_CONNECTED + } + return linux.SS_UNCONNECTED +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 43ff875e4..c034cf984 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa return ready } + +// State implements socket.Socket.State. +func (e *connectionlessEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + + switch { + case e.isBound(): + return linux.SS_UNCONNECTED + case e.Connected(): + return linux.SS_CONNECTING + default: + return linux.SS_DISCONNECTING + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37d82bb6b..5fc09af55 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -191,6 +191,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + + // State returns the current state of the socket, as represented by Linux in + // procfs. + State() uint32 } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 388cc0d8b..375542350 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -596,6 +596,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } +// State implements socket.Socket.State. +func (s *SocketOperations) State() uint32 { + return s.ep.State() +} + // provider is a unix domain socket provider. type provider struct{} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8d74f1543..e8a9392b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } +func (f *fakeTransportEndpoint) State() uint32 { + return 0 +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9886c6e4..85ef014d0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -377,6 +377,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 } // WriteOptions contains options for Endpoint.Write. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9aa6f3978..84a2b53b7 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e2b90ef10..b8005093a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't +// expose internal socket state. +func (e *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 1daf5823f..e4ff50c91 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b ep.waiterQueue.Notify(waiter.EventIn) } } + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 31e365ae5..a32e20b06 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -226,7 +226,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n.isRegistered = true - n.state = stateConnecting // Create sender and receiver. // @@ -258,8 +257,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.Close() return nil, err } - - ep.state = stateConnected + ep.mu.Lock() + ep.state = StateEstablished + ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -276,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() state := e.state e.mu.RUnlock() - if state == stateListen { + if state == StateListen { e.acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { @@ -406,7 +406,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.state = stateConnected + n.state = StateEstablished // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -429,7 +429,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = stateClosed + e.state = StateClose // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 371d2ed29..0ad7bfb38 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -151,6 +151,9 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.listenEP = listenEP + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -219,6 +222,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: rcvSynOpts.TS, @@ -668,7 +674,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -719,8 +725,7 @@ func (e *endpoint) handleClose() *tcpip.Error { // protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - - e.state = stateError + e.state = StateError e.hardError = err } @@ -876,14 +881,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // handshake, and then inform potential waiters about its // completion. h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + e.mu.Lock() + h.ep.state = StateSynSent + e.mu.Unlock() + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.mu.Lock() - e.state = stateError + e.state = StateError e.hardError = err + // Lock released below. epilogue() @@ -905,7 +915,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = stateConnected + e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1005,7 +1015,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != stateError { + if e.state != StateError { close(e.drainDone) <-e.undrain } @@ -1061,8 +1071,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() - if e.state != stateError { - e.state = stateClosed + if e.state != StateError { + e.state = StateClose } // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fd697402e..23422ca5e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -32,18 +32,81 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -type endpointState int +// EndpointState represents the state of a TCP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateListen - stateConnecting - stateConnected - stateClosed - stateError + // Endpoint states internal to netstack. These map to the TCP state CLOSED. + StateInitial EndpointState = iota + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError + + // TCP protocol states. + StateEstablished + StateSynSent + StateSynRecv + StateFinWait1 + StateFinWait2 + StateTimeWait + StateClose + StateCloseWait + StateLastAck + StateListen + StateClosing ) +// connected is the set of states where an endpoint is connected to a peer. +func (s EndpointState) connected() bool { + switch s { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + return true + default: + return false + } +} + +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnecting: + return "CONNECTING" + case StateError: + return "ERROR" + case StateEstablished: + return "ESTABLISHED" + case StateSynSent: + return "SYN-SENT" + case StateSynRecv: + return "SYN-RCVD" + case StateFinWait1: + return "FIN-WAIT1" + case StateFinWait2: + return "FIN-WAIT2" + case StateTimeWait: + return "TIME-WAIT" + case StateClose: + return "CLOSED" + case StateCloseWait: + return "CLOSE-WAIT" + case StateLastAck: + return "LAST-ACK" + case StateListen: + return "LISTEN" + case StateClosing: + return "CLOSING" + default: + panic("unreachable") + } +} + // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota @@ -108,10 +171,14 @@ type endpoint struct { rcvBufUsed int // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID - state endpointState `state:".(endpointState)"` - isPortReserved bool `state:"manual"` + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + + // state endpointState `state:".(endpointState)"` + // pState ProtocolState + state EndpointState `state:".(EndpointState)"` + + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` @@ -304,6 +371,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite stack: stack, netProto: netProto, waiterQueue: waiterQueue, + state: StateInitial, rcvBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize, sndMTU: int(math.MaxInt32), @@ -351,14 +419,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { defer e.mu.RUnlock() switch e.state { - case stateInitial, stateBound, stateConnecting: + case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case stateClosed, stateError: + case StateClose, StateError: // Ready for anything. result = mask - case stateListen: + case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { if len(e.acceptedChan) > 0 { @@ -366,7 +434,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -427,7 +495,7 @@ func (e *endpoint) Close() { // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == stateListen && e.isPortReserved { + if e.state == StateListen && e.isPortReserved { if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false @@ -487,15 +555,15 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RLock() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received - // would cause the state to become stateError so we should allow the + // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.hardError e.mu.RUnlock() - if s == stateError { + if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -511,7 +579,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -547,9 +615,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c defer e.mu.RUnlock() // The endpoint cannot be written to if it's not connected. - if e.state != stateConnected { + if !e.state.connected() { switch e.state { - case stateError: + case StateError: return 0, nil, e.hardError default: return 0, nil, tcpip.ErrClosedForSend @@ -612,8 +680,8 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; s != stateConnected && s != stateClosed { - if s == stateError { + if s := e.state; !s.connected() && s != StateClose { + if s == StateError { return 0, tcpip.ControlMessages{}, e.hardError } return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -623,7 +691,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -789,7 +857,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -841,7 +909,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == stateListen { + if e.state == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1057,7 +1125,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid := addr.NIC switch e.state { - case stateBound: + case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. if e.boundNICID == 0 { @@ -1070,16 +1138,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid = e.boundNICID - case stateInitial: - // Nothing to do. We'll eventually fill-in the gaps in the ID - // (if any) when we find a route. + case StateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) + // when we find a route. - case stateConnecting: - // A connection request has already been issued but hasn't - // completed yet. + case StateConnecting, StateSynSent, StateSynRecv: + // A connection request has already been issued but hasn't completed + // yet. return tcpip.ErrAlreadyConnecting - case stateConnected: + case StateEstablished: // The endpoint is already connected. If caller hasn't been notified yet, return success. if !e.isConnectNotified { e.isConnectNotified = true @@ -1088,7 +1156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // Otherwise return that it's already connected. return tcpip.ErrAlreadyConnected - case stateError: + case StateError: return e.hardError default: @@ -1154,7 +1222,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.isRegistered = true - e.state = stateConnecting + e.state = StateConnecting e.route = r.Clone() e.boundNICID = nicid e.effectiveNetProtos = netProtos @@ -1175,7 +1243,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = stateConnected + e.state = StateEstablished } if run { @@ -1199,8 +1267,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { defer e.mu.Unlock() e.shutdownFlags |= flags - switch e.state { - case stateConnected: + switch { + case e.state.connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1241,7 +1309,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.sndCloseWaker.Assert() } - case stateListen: + case e.state == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1269,7 +1337,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == stateListen && !e.workerCleanup { + if e.state == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1288,7 +1356,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Endpoint must be bound before it can transition to listen mode. - if e.state != stateBound { + if e.state != StateBound { return tcpip.ErrInvalidEndpointState } @@ -1298,7 +1366,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } e.isRegistered = true - e.state = stateListen + e.state = StateListen if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } @@ -1325,7 +1393,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != stateListen { + if e.state != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -1353,7 +1421,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrAlreadyBound } @@ -1408,7 +1476,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound return nil } @@ -1430,7 +1498,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if !e.state.connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1739,3 +1807,11 @@ func (e *endpoint) initGSO() { gso.MaxSize = e.route.GSOMaxSize() e.gso = gso } + +// State implements tcpip.Endpoint.State. It exports the endpoint's protocol +// state for diagnostics. +func (e *endpoint) State() uint32 { + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index e8aed2875..5f30c2374 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial, stateBound: - case stateConnected: + case StateInitial, StateBound: + case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -66,17 +66,17 @@ func (e *endpoint) beforeSave() { break } fallthrough - case stateListen, stateConnecting: + case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != stateClosed && e.state != stateError { + if e.state != StateClose && e.state != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } fallthrough - case stateError, stateClosed: - for e.state == stateError && e.workerRunning { + case StateError, StateClose: + for e.state == StateError && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -92,7 +92,7 @@ func (e *endpoint) beforeSave() { panic("endpoint still has waiters upon save") } - if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) { + if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { } // saveState is invoked by stateify. -func (e *endpoint) saveState() endpointState { +func (e *endpoint) saveState() EndpointState { return e.state } @@ -146,15 +146,15 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state endpointState) { +func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: connectedLoading.Add(1) - case stateListen: + case StateListen: listenLoading.Add(1) - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } e.state = state @@ -168,7 +168,7 @@ func (e *endpoint) afterLoad() { state := e.state switch state { - case stateInitial, stateBound, stateListen, stateConnecting, stateConnected: + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -181,7 +181,7 @@ func (e *endpoint) afterLoad() { } bind := func() { - e.state = stateInitial + e.state = StateInitial if len(e.bindAddress) == 0 { e.bindAddress = e.id.LocalAddress } @@ -191,7 +191,7 @@ func (e *endpoint) afterLoad() { } switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { // This endpoint is accepted by netstack but not yet by @@ -211,7 +211,7 @@ func (e *endpoint) afterLoad() { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() - case stateListen: + case StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -223,7 +223,7 @@ func (e *endpoint) afterLoad() { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -235,7 +235,7 @@ func (e *endpoint) afterLoad() { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case stateBound: + case StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -244,7 +244,7 @@ func (e *endpoint) afterLoad() { bind() tcpip.AsyncLoading.Done() }() - case stateClosed: + case StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -252,12 +252,12 @@ func (e *endpoint) afterLoad() { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = stateClosed + e.state = StateClose tcpip.AsyncLoading.Done() }() } fallthrough - case stateError: + case StateError: tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index b08a0e356..f02fa6105 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -134,6 +134,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // sequence numbers that have been consumed. TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { r.rcvNxt++ @@ -144,6 +145,25 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.closed = true r.ep.readyToRead(nil) + // We just received a FIN, our next state depends on whether we sent a + // FIN already or not. + r.ep.mu.Lock() + switch r.ep.state { + case StateEstablished: + r.ep.state = StateCloseWait + case StateFinWait1: + if s.flagIsSet(header.TCPFlagAck) { + // FIN-ACK, transition to TIME-WAIT. + r.ep.state = StateTimeWait + } else { + // Simultaneous close, expecting a final ACK. + r.ep.state = StateClosing + } + case StateFinWait2: + r.ep.state = StateTimeWait + } + r.ep.mu.Unlock() + // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the // caller is using it. @@ -156,6 +176,23 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.pendingRcvdSegments[i].decRef() } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + + return true + } + + // Handle ACK (not FIN-ACK, which we handled above) during one of the + // shutdown states. + if s.flagIsSet(header.TCPFlagAck) { + r.ep.mu.Lock() + switch r.ep.state { + case StateFinWait1: + r.ep.state = StateFinWait2 + case StateClosing: + r.ep.state = StateTimeWait + case StateLastAck: + r.ep.state = StateClose + } + r.ep.mu.Unlock() } return true diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 3464e4be7..b236d7af2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -632,6 +632,10 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) + // Transition to FIN-WAIT1 state since we're initiating an active close. + s.ep.mu.Lock() + s.ep.state = StateFinWait1 + s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b8f0ccaf1..56b490aaa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) { // Receive the SYN-ACK reply. b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) { time.Sleep(3 * time.Second) for { b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + tcpHdr := header.TCP(header.IPv4(b).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { // This is a retransmit of the FIN, ignore it. continue } @@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } - // This final should be ignored because an ACK on a reset doesn't - // mean anything. + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.SeqNum(uint32(c.IRS)+1), )) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Cause a RST to be generated by closing the read end now since we have // unread data. c.EP.Shutdown(tcpip.ShutdownRead) @@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } // The ACK to the FIN should now be rejected since the connection has been // closed by a RST. @@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { for bytesReceived != dataLen { b := c.GetPacket() numPackets++ - tcp := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcp.Payload()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), @@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); !bytes.Equal(pdata, p) { + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen @@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { if c.TimeStampEnabled { // If timestamp option is enabled, echo back the timestamp and increment // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcp.ParsedOptions() + parsedOpts := tcpHdr.ParsedOptions() tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) options = tsOpt[:] @@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) // Wait for retransmit. time.Sleep(1 * time.Second) @@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcp.SourcePort()), - checker.SeqNum(tcp.SequenceNumber()), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Send SYN-ACK. iss := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -2523,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcp.AckNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) if ack == last { break } @@ -2568,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) { ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Send some data and acknowledge the FIN. data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -2589,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. + // Give the stack the chance to transition to closed state. Note that since + // both the sender and receiver are now closed, we effectively skip the + // TIME-WAIT state. time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Wait for receive to be notified. select { case <-ch: @@ -3680,9 +3702,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } stats := c.Stack().Stats() want := stats.TCP.PassiveConnectionOpenings.Value() + 1 @@ -3826,3 +3854,68 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } } + +func TestEndpointBindListenAcceptState(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + aep, _, err := ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + aep, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Listening endpoint remains in listen state. + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + ep.Close() + // Give worker goroutines time to receive the close notification. + time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Accepted endpoint remains open when the listen endpoint is closed. + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e12413c6..69a43b6f4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -532,6 +532,9 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if epRcvBuf != nil { if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { @@ -557,13 +560,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. checker.TCPFlags(header.TCPFlagSyn), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -591,8 +597,11 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - c.Port = tcp.SourcePort() + c.Port = tcpHdr.SourcePort() } // RawEndpoint is just a small wrapper around a TCP endpoint's state to make @@ -690,6 +699,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -719,6 +731,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpSeg := header.TCP(header.IPv4(b).Payload()) synOptions := header.ParseSynOptions(tcpSeg.Options(), false) @@ -782,6 +798,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Store the source port in use by the endpoint. c.Port = tcpSeg.SourcePort() @@ -821,10 +840,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := ep.Listen(10); err != nil { c.t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) @@ -847,6 +872,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Fatalf("Timed out waiting for accept") } } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + return rep } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3d52a4f31..fa7278286 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements socket.Socket.State. +func (e *endpoint) State() uint32 { + // TODO(b/112063468): Translate internal state to values returned by Linux. + return 0 +} diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 6d745f728..82d325c17 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -34,6 +34,16 @@ using absl::StrFormat; constexpr char kProcNetUnixHeader[] = "Num RefCount Protocol Flags Type St Inode Path"; +// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux +// kernel, include/uapi/linux/net.h. +enum { + SS_FREE = 0, // Not allocated + SS_UNCONNECTED, // Unconnected to any socket + SS_CONNECTING, // In process of connecting + SS_CONNECTED, // Connected to socket + SS_DISCONNECTING // In process of disconnecting +}; + // UnixEntry represents a single entry from /proc/net/unix. struct UnixEntry { uintptr_t addr; @@ -71,7 +81,12 @@ PosixErrorOr> ProcNetUnixEntries() { bool skipped_header = false; std::vector entries; std::vector lines = absl::StrSplit(content, absl::ByAnyChar("\n")); + std::cerr << "" << std::endl; for (std::string line : lines) { + // Emit the proc entry to the test output to provide context for the test + // results. + std::cerr << line << std::endl; + if (!skipped_header) { EXPECT_EQ(line, kProcNetUnixHeader); skipped_header = true; @@ -139,6 +154,7 @@ PosixErrorOr> ProcNetUnixEntries() { entries.push_back(entry); } + std::cerr << "" << std::endl; return entries; } @@ -241,6 +257,168 @@ TEST(ProcNetUnix, SocketPair) { EXPECT_EQ(entries.size(), 2); } +TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, StreamSocketStateStateUnconnectedOnListen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); + + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); + // The bind and listen entries should refer to the same socket. + EXPECT_EQ(listen_entry.inode, bind_entry.inode); +} + +TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + const std::string address = ExtractPath(sockets->first_addr()); + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int clientfd; + ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), + SyscallSucceeds()); + + // Find the entry for the accepted socket. UDS proc entries don't have a + // remote address, so we distinguish the accepted socket from the listen + // socket by checking for a different inode. + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry accept_entry; + ASSERT_TRUE(FindBy( + entries, &accept_entry, [address, listen_entry](const UnixEntry& e) { + return e.path == address && e.inode != listen_entry.inode; + })); + EXPECT_EQ(accept_entry.state, SS_CONNECTED); + // Listen entry should still be in SS_UNCONNECTED state. + ASSERT_TRUE(FindBy(entries, &listen_entry, + [&sockets, listen_entry](const UnixEntry& e) { + return e.path == ExtractPath(sockets->first_addr()) && + e.inode == listen_entry.inode; + })); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // Once again, we have no easy way to identify the connecting socket as it has + // no listed address. We can only identify the entry as the "non-bind socket + // entry" on gVisor, where we're guaranteed to have only the two entries we + // create during this test. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + UnixEntry connect_entry; + ASSERT_TRUE( + FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { + return e.inode != bind_entry.inode; + })); + EXPECT_EQ(connect_entry.state, SS_CONNECTING); + } +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 02ab1f187cd24c67b754b004229421d189cee264 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 16:44:40 -0700 Subject: Copy up parent when binding UDS on overlayfs Overlayfs was expecting the parent to exist when bind(2) was called, which may not be the case. The fix is to copy the parent directory to the upper layer before binding the UDS. There is not good place to add tests for it. Syscall tests would be ideal, but it's hard to guarantee that the directory where the socket is created hasn't been touched before (and thus copied the parent to the upper layer). Added it to runsc integration tests for now. If it turns out we have lots of these kind of tests, we can consider moving them somewhere more appropriate. PiperOrigin-RevId: 251954156 --- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/inode.go | 4 +-- pkg/sentry/fs/inode_overlay.go | 12 ++++----- runsc/test/integration/BUILD | 1 + runsc/test/integration/regression_test.go | 45 +++++++++++++++++++++++++++++++ 5 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 runsc/test/integration/regression_test.go (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index c0bc261a2..a0a35c242 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -805,7 +805,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans var childDir *Dirent err := d.genericCreate(ctx, root, name, func() error { var e error - childDir, e = d.Inode.Bind(ctx, name, data, perms) + childDir, e = d.Inode.Bind(ctx, d, name, data, perms) if e != nil { return e } diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index aef1a1cb9..0b54c2e77 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -220,9 +220,9 @@ func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent, } // Bind calls i.InodeOperations.Bind with i as the directory. -func (i *Inode) Bind(ctx context.Context, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func (i *Inode) Bind(ctx context.Context, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { if i.overlay != nil { - return overlayBind(ctx, i.overlay, name, data, perm) + return overlayBind(ctx, i.overlay, parent, name, data, perm) } return i.InodeOperations.Bind(ctx, i, name, data, perm) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index cdffe173b..06506fb20 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -398,14 +398,14 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena return nil } -func overlayBind(ctx context.Context, o *overlayEntry, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { + if err := copyUp(ctx, parent); err != nil { + return nil, err + } + o.copyMu.RLock() defer o.copyMu.RUnlock() - // We do not support doing anything exciting with sockets unless there - // is already a directory in the upper filesystem. - if o.upper == nil { - return nil, syserror.EOPNOTSUPP - } + d, err := o.upper.InodeOperations.Bind(ctx, o.upper, name, data, perm) if err != nil { return nil, err diff --git a/runsc/test/integration/BUILD b/runsc/test/integration/BUILD index 0c4e4fa80..04ed885c6 100644 --- a/runsc/test/integration/BUILD +++ b/runsc/test/integration/BUILD @@ -8,6 +8,7 @@ go_test( srcs = [ "exec_test.go", "integration_test.go", + "regression_test.go", ], embed = [":integration"], tags = [ diff --git a/runsc/test/integration/regression_test.go b/runsc/test/integration/regression_test.go new file mode 100644 index 000000000..80bae9970 --- /dev/null +++ b/runsc/test/integration/regression_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "strings" + "testing" + + "gvisor.googlesource.com/gvisor/runsc/test/testutil" +) + +// Test that UDS can be created using overlay when parent directory is in lower +// layer only (b/134090485). +// +// Prerequisite: the directory where the socket file is created must not have +// been open for write before bind(2) is called. +func TestBindOverlay(t *testing.T) { + if err := testutil.Pull("ubuntu:trusty"); err != nil { + t.Fatal("docker pull failed:", err) + } + d := testutil.MakeDocker("bind-overlay-test") + + cmd := "nc -l -U /var/run/sock& sleep 1 && echo foobar-asdf | nc -U /var/run/sock" + got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) + if err != nil { + t.Fatal("docker run failed:", err) + } + + if want := "foobar-asdf"; !strings.Contains(got, want) { + t.Fatalf("docker run output is missing %q: %s", want, got) + } + defer d.CleanUp() +} -- cgit v1.2.3 From 315cf9a523d409dc6ddd5ce25f8f0315068ccc67 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 16:59:21 -0700 Subject: Use common definition of SockType. SockType isn't specific to unix domain sockets, and the current definition basically mirrors the linux ABI's definition. PiperOrigin-RevId: 251956740 --- pkg/abi/linux/socket.go | 18 +++++++++++------- pkg/sentry/fs/gofer/socket.go | 11 ++++++----- pkg/sentry/fs/host/socket.go | 11 ++++++----- pkg/sentry/socket/epsocket/BUILD | 1 - pkg/sentry/socket/epsocket/epsocket.go | 11 +++++------ pkg/sentry/socket/epsocket/provider.go | 7 +++---- pkg/sentry/socket/hostinet/BUILD | 1 - pkg/sentry/socket/hostinet/socket.go | 5 ++--- pkg/sentry/socket/netlink/provider.go | 7 +++---- pkg/sentry/socket/rpcinet/BUILD | 1 - pkg/sentry/socket/rpcinet/socket.go | 9 ++++----- pkg/sentry/socket/socket.go | 8 ++++---- pkg/sentry/socket/unix/transport/connectioned.go | 20 ++++++++++---------- pkg/sentry/socket/unix/transport/connectionless.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 22 ++++------------------ pkg/sentry/socket/unix/unix.go | 4 ++-- pkg/sentry/strace/socket.go | 14 +++++++------- pkg/sentry/syscalls/linux/sys_socket.go | 4 ++-- 18 files changed, 71 insertions(+), 87 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 44bd69df6..a714ac86d 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -102,15 +102,19 @@ const ( SOL_NETLINK = 270 ) +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// below as SOCK_* constants. +type SockType int + // Socket types, from linux/net.h. const ( - SOCK_STREAM = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_STREAM SockType = 1 + SOCK_DGRAM = 2 + SOCK_RAW = 3 + SOCK_RDM = 4 + SOCK_SEQPACKET = 5 + SOCK_DCCP = 6 + SOCK_PACKET = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 7376fd76f..7ac0a421f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -15,6 +15,7 @@ package gofer import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/p9" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -61,13 +62,13 @@ type endpoint struct { path string } -func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { +func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { switch t { - case transport.SockStream: + case linux.SOCK_STREAM: return p9.StreamSocket, true - case transport.SockSeqpacket: + case linux.SOCK_SEQPACKET: return p9.SeqpacketSocket, true - case transport.SockDgram: + case linux.SOCK_DGRAM: return p9.DgramSocket, true } return 0, false @@ -75,7 +76,7 @@ func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := unixSockToP9(ce.Type()) + cf, ok := sockTypeToP9(ce.Type()) if !ok { return syserr.ErrConnectionRefused } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e4ec0f62c..6423ad938 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,6 +19,7 @@ import ( "sync" "syscall" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" "gvisor.googlesource.com/gvisor/pkg/log" @@ -56,7 +57,7 @@ type ConnectedEndpoint struct { srfd int `state:"wait"` // stype is the type of Unix socket. - stype transport.SockType + stype linux.SockType // sndbuf is the size of the send buffer. // @@ -105,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { return syserr.ErrInvalidEndpointState } - c.stype = transport.SockType(stype) + c.stype = linux.SockType(stype) c.sndbuf = sndbuf return nil @@ -163,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -195,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil + return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil } // Send implements transport.ConnectedEndpoint.Send. @@ -209,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == transport.SockStream + truncate := c.stype == linux.SOCK_STREAM n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 44bb97b5b..7e2679ea0 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f91c5127a..e1e29de35 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -44,7 +44,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -228,7 +227,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint - skType transport.SockType + skType linux.SockType // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -253,8 +252,8 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { - if skType == transport.SockStream { +func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { + if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -638,7 +637,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -664,7 +663,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_TYPE: diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index ec930d8d5..e48a106ea 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -23,7 +23,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -42,7 +41,7 @@ type provider struct { // getTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco } // Socket creates a new socket object for the AF_INET or AF_INET6 family. -func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() if stack == nil { @@ -116,7 +115,7 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int } // Pair just returns nil sockets (not supported). -func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a469af7ac..975f47bc3 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d75580a3..4517951a0 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -30,7 +30,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" @@ -548,7 +547,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the host network stack. stack := t.NetworkContext() if stack == nil { @@ -590,7 +589,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 76cf12fd4..863edc241 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -22,7 +22,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" ) @@ -66,10 +65,10 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Netlink sockets must be specified as datagram or raw, but they // behave the same regardless of type. - if stype != transport.SockDgram && stype != transport.SockRaw { + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { return nil, syserr.ErrSocketNotSupported } @@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol } // Pair implements socket.Provider.Pair by returning an error. -func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Netlink sockets never supports creating socket pairs. return nil, nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 4da14a1e0..33ba20de7 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/rpcinet/conn", "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index bf42bdf69..2d5b5b58f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -32,7 +32,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier" pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -70,7 +69,7 @@ type socketOperations struct { var _ = socket.Socket(&socketOperations{}) // New creates a new RPC socket. -func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) { +func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) { id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */) <-c @@ -841,7 +840,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the RPC network stack. stack := t.NetworkContext() if stack == nil { @@ -857,7 +856,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p // // Try to restrict the flags we will accept to minimize backwards // incompatibility with netstack. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -881,7 +880,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index a99423365..f1021ec67 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -130,12 +130,12 @@ type Provider interface { // If a nil Socket _and_ a nil error is returned, it means that the // protocol is not supported. A non-nil error should only be returned // if the protocol is supported, but an error occurs during creation. - Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) // Pair creates a pair of connected sockets. // // See Socket for error information. - Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) } // families holds a map of all known address families and their providers. @@ -149,7 +149,7 @@ func RegisterProvider(family int, provider Provider) { } // New creates a new socket with the given family, type and protocol. -func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { for _, p := range families[family] { s, err := p.Socket(t, stype, protocol) if err != nil { @@ -166,7 +166,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f // Pair creates a new connected socket pair with the given family, type and // protocol. -func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { providers, ok := families[family] if !ok { return nil, nil, syserr.ErrAddressFamilyNotSupported diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9c8ec0365..db79ac904 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -45,7 +45,7 @@ type ConnectingEndpoint interface { // Type returns the socket type, typically either SockStream or // SockSeqpacket. The connection attempt must be aborted if this // value doesn't match the ConnectableEndpoint's type. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the bound path. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -101,7 +101,7 @@ type connectionedEndpoint struct { // stype is used by connecting sockets to ensure that they are the // same type. The value is typically either tcpip.SockSeqpacket or // tcpip.SockStream. - stype SockType + stype linux.SockType // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and @@ -112,7 +112,7 @@ type connectionedEndpoint struct { } // NewConnectioned creates a new unbound connectionedEndpoint. -func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { +func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -122,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. -func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { +func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -139,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - if stype == SockStream { + if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} } else { @@ -163,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. -func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { +func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), @@ -178,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 { } // Type implements ConnectingEndpoint.Type and Endpoint.Type. -func (e *connectionedEndpoint) Type() SockType { +func (e *connectionedEndpoint) Type() linux.SockType { return e.stype } @@ -294,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { ne.receiver = &queueReceiver{readQueue: writeQueue} @@ -309,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur writeQueue: writeQueue, } readQueue.IncRef() - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) @@ -429,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. - if e.stype == SockStream && to != nil { + if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } return e.baseEndpoint.SendMsg(data, c, to) diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index c034cf984..81ebfba10 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -119,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo } // Type implements Endpoint.Type. -func (e *connectionlessEndpoint) Type() SockType { - return SockDgram +func (e *connectionlessEndpoint) Type() linux.SockType { + return linux.SOCK_DGRAM } // Connect attempts to connect directly to server. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 5fc09af55..5c55c529e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -19,6 +19,7 @@ import ( "sync" "sync/atomic" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -28,21 +29,6 @@ import ( // initialLimit is the starting limit for the socket buffers. const initialLimit = 16 * 1024 -// A SockType is a type (as opposed to family) of sockets. These are enumerated -// in the syscall package as syscall.SOCK_* constants. -type SockType int - -const ( - // SockStream corresponds to syscall.SOCK_STREAM. - SockStream SockType = 1 - // SockDgram corresponds to syscall.SOCK_DGRAM. - SockDgram SockType = 2 - // SockRaw corresponds to syscall.SOCK_RAW. - SockRaw SockType = 3 - // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. - SockSeqpacket SockType = 5 -) - // A RightsControlMessage is a control message containing FDs. type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. @@ -175,7 +161,7 @@ type Endpoint interface { // Type return the socket type, typically either SockStream, SockDgram // or SockSeqpacket. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -629,7 +615,7 @@ type connectedEndpoint struct { GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Type implements Endpoint.Type. - Type() SockType + Type() linux.SockType } writeQueue *queue @@ -653,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, } truncate := false - if e.endpoint.Type() == SockStream { + if e.endpoint.Type() == linux.SOCK_STREAM { // Since stream sockets don't preserve message boundaries, we // can write only as much of the message as fits in the queue. truncate = true diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 375542350..56ed63e21 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -605,7 +605,7 @@ func (s *SocketOperations) State() uint32 { type provider struct{} // Socket returns a new unix domain socket. -func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, syserr.ErrProtocolNotSupported @@ -631,7 +631,7 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) } // Pair creates a new pair of AF_UNIX connected sockets. -func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, nil, syserr.ErrProtocolNotSupported diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index dbe53b9a2..0b5ef84c4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -76,13 +76,13 @@ var SocketFamily = abi.ValueSet{ // SocketType are the possible socket(2) types. var SocketType = abi.ValueSet{ - linux.SOCK_STREAM: "SOCK_STREAM", - linux.SOCK_DGRAM: "SOCK_DGRAM", - linux.SOCK_RAW: "SOCK_RAW", - linux.SOCK_RDM: "SOCK_RDM", - linux.SOCK_SEQPACKET: "SOCK_SEQPACKET", - linux.SOCK_DCCP: "SOCK_DCCP", - linux.SOCK_PACKET: "SOCK_PACKET", + uint64(linux.SOCK_STREAM): "SOCK_STREAM", + uint64(linux.SOCK_DGRAM): "SOCK_DGRAM", + uint64(linux.SOCK_RAW): "SOCK_RAW", + uint64(linux.SOCK_RDM): "SOCK_RDM", + uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET", + uint64(linux.SOCK_DCCP): "SOCK_DCCP", + uint64(linux.SOCK_PACKET): "SOCK_PACKET", } // SocketFlagSet are the possible socket(2) flags. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 8f4dbf3bc..31295a6a9 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -188,7 +188,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Create the new socket. - s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol) + s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -227,7 +227,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Create the socket pair. - s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol) + s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } -- cgit v1.2.3 From a00157cc0e216a9829f2659ce35c856a22aa5ba2 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Mon, 10 Jun 2019 15:16:42 -0700 Subject: Store more information in the kernel socket table. Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895 --- pkg/sentry/fs/host/socket.go | 4 +-- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/net.go | 14 ++++---- pkg/sentry/kernel/BUILD | 13 +++++++ pkg/sentry/kernel/kernel.go | 55 +++++++++++++--------------- pkg/sentry/socket/epsocket/epsocket.go | 13 +++++-- pkg/sentry/socket/epsocket/provider.go | 2 +- pkg/sentry/socket/hostinet/socket.go | 32 +++++++++++------ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 12 ++++++- pkg/sentry/socket/rpcinet/socket.go | 16 +++++++-- pkg/sentry/socket/socket.go | 9 +++-- pkg/sentry/socket/unix/unix.go | 65 ++++++++++++++++++++-------------- 13 files changed, 152 insertions(+), 86 deletions(-) (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 6423ad938..305eea718 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -164,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -196,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil + return unixsocket.New(ctx, ep, e.stype), nil } // Send implements transport.ConnectedEndpoint.Send. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index d19c360e0..1728fe0b5 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -45,6 +45,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/mm", + "//pkg/sentry/socket", "//pkg/sentry/socket/rpcinet", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 3daaa962c..034950158 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -27,6 +27,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/inet" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) @@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n") // Entries - for _, sref := range n.k.ListSockets(linux.AF_UNIX) { - s := sref.Get() + for _, se := range n.k.ListSockets() { + s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref) + log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) continue } sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(*unix.SocketOperations) - if !ok { - panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile)) + if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { + // Not a unix socket. + continue } + sops := sfile.FileOperations.(*unix.SocketOperations) addr, err := sops.Endpoint().GetLocalAddress() if err != nil { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 99a2fd964..04e375910 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -64,6 +64,18 @@ go_template_instance( }, ) +go_template_instance( + name = "socket_list", + out = "socket_list.go", + package = "kernel", + prefix = "socket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SocketEntry", + "Linker": "*SocketEntry", + }, +) + proto_library( name = "uncaught_signal_proto", srcs = ["uncaught_signal.proto"], @@ -104,6 +116,7 @@ go_library( "sessions.go", "signal.go", "signal_handlers.go", + "socket_list.go", "syscalls.go", "syscalls_state.go", "syslog.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 85d73ace2..f253a81d9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -182,9 +182,13 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // socketTable is used to track all sockets on the system. Protected by + // sockets is the list of all network sockets the system. Protected by // extMu. - socketTable map[int]map[*refs.WeakRef]struct{} + sockets socketList + + // nextSocketEntry is the next entry number to use in sockets. Protected + // by extMu. + nextSocketEntry uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() - k.socketTable = make(map[int]map[*refs.WeakRef]struct{}) return nil } @@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) } -// socketEntry represents a socket recorded in Kernel.socketTable. It implements +// SocketEntry represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type socketEntry struct { - k *Kernel - sock *refs.WeakRef - family int +type SocketEntry struct { + socketEntry + k *Kernel + Sock *refs.WeakRef + ID uint64 // Socket table entry number. } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *socketEntry) WeakRefGone() { +func (s *SocketEntry) WeakRefGone() { s.k.extMu.Lock() - // k.socketTable is guaranteed to point to a valid socket table for s.family - // at this point, since we made sure of the fact when we created this - // socketEntry, and we never delete socket tables. - delete(s.k.socketTable[s.family], s.sock) + s.k.sockets.Remove(s) s.k.extMu.Unlock() } // RecordSocket adds a socket to the system-wide socket table for tracking. // // Precondition: Caller must hold a reference to sock. -func (k *Kernel) RecordSocket(sock *fs.File, family int) { +func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - table, ok := k.socketTable[family] - if !ok { - table = make(map[*refs.WeakRef]struct{}) - k.socketTable[family] = table - } - se := socketEntry{k: k, family: family} - se.sock = refs.NewWeakRef(sock, &se) - table[se.sock] = struct{}{} + id := k.nextSocketEntry + k.nextSocketEntry++ + s := &SocketEntry{k: k, ID: id} + s.Sock = refs.NewWeakRef(sock, s) + k.sockets.PushBack(s) k.extMu.Unlock() } -// ListSockets returns a snapshot of all sockets of a given family. -func (k *Kernel) ListSockets(family int) []*refs.WeakRef { +// ListSockets returns a snapshot of all sockets. +func (k *Kernel) ListSockets() []*SocketEntry { k.extMu.Lock() - socks := []*refs.WeakRef{} - if table, ok := k.socketTable[family]; ok { - socks = make([]*refs.WeakRef, 0, len(table)) - for s := range table { - socks = append(socks, s) - } + var socks []*SocketEntry + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, s) } k.extMu.Unlock() return socks diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1e29de35..f67451179 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -228,6 +228,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint skType linux.SockType + protocol int // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -252,7 +253,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) @@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, family: family, Endpoint: endpoint, skType: skType, + protocol: protocol, }), nil } @@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns, err := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, s.protocol, wq, ep) if err != nil { return 0, nil, 0, err } @@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(ns, s.family) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, syserr.FromError(e) } @@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 { // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } + +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.skType, s.protocol +} diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index e48a106ea..516582828 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep) + return New(t, p.family, stype, protocol, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 4517951a0..c62c8d8f1 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -56,15 +56,22 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{family: family, fd: fd} +func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(f, s.family) + t.Kernel().RecordSocket(f) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 { return uint32(info.State) } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Only accept TCP and UDP. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto // Conservatively ignore all flags specified by the application and add // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 863edc241..5dc103877 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int return nil, err } - s, err := NewSocket(t, p) + s, err := NewSocket(t, stype, p) if err != nil { return nil, err } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 16c79aa33..62659784a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -80,6 +80,10 @@ type Socket struct { // protocol is the netlink protocol implementation. protocol Protocol + // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for + // netlink sockets. + skType linux.SockType + // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. ep transport.Endpoint @@ -105,7 +109,7 @@ type Socket struct { var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. -func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { +func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. ep := transport.NewConnectionless() @@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { return &Socket{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, + skType: skType, ep: ep, connection: connection, sendBufferSize: defaultSendBufferSize, @@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, func (s *Socket) State() uint32 { return s.ep.State() } + +// Type implements socket.Socket.Type. +func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { + return linux.AF_NETLINK, s.skType, s.protocol.Protocol() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 2d5b5b58f..c22ff1ff0 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -53,7 +53,10 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd uint32 // must be O_NONBLOCK wq *waiter.Queue rpcConn *conn.RPCConnection @@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ family: family, + stype: skType, + protocol: protocol, wq: &wq, fd: fd, rpcConn: stack.rpcConn, @@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, syserr.FromError(err) } - t.Kernel().RecordSocket(file, s.family) + t.Kernel().RecordSocket(file) if peerRequested { return fd, payload.Address.Address, payload.Address.Length, nil @@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 { return 0 } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto return nil, nil } - return newSocketFile(t, s, p.family, stype, 0) + return newSocketFile(t, s, p.family, stype, protocol) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f1021ec67..d60944b6b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -120,6 +120,9 @@ type Socket interface { // State returns the current state of the socket, as represented by Linux in // procfs. The returned state value is protocol-specific. State() uint32 + + // Type returns the family, socket type and protocol of the socket. + Type() (family int, skType linux.SockType, protocol int) } // Provider is the interface implemented by providers of sockets for specific @@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi return nil, err } if s != nil { - t.Kernel().RecordSocket(s, family) + t.Kernel().RecordSocket(s) return s, nil } } @@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F } if s1 != nil && s2 != nil { k := t.Kernel() - k.RecordSocket(s1, family) - k.RecordSocket(s2, family) + k.RecordSocket(s1) + k.RecordSocket(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 56ed63e21..b07e8d67b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -17,6 +17,7 @@ package unix import ( + "fmt" "strings" "syscall" @@ -55,22 +56,22 @@ type SocketOperations struct { refs.AtomicRefCount socket.SendReceiveTimeout - ep transport.Endpoint - isPacket bool + ep transport.Endpoint + stype linux.SockType } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, - isPacket: isPacket, + ep: ep, + stype: stype, }) } @@ -88,6 +89,18 @@ func (s *SocketOperations) Release() { s.DecRef() } +func (s *SocketOperations) isPacket() bool { + switch s.stype { + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + return true + case linux.SOCK_STREAM: + return false + default: + // We shouldn't have allowed any other socket types during creation. + panic(fmt.Sprintf("Invalid socket type %d", s.stype)) + } +} + // Endpoint extracts the transport.Endpoint. func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep @@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep, s.isPacket) + ns := New(t, ep, s.stype) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } - t.Kernel().RecordSocket(ns, linux.AF_UNIX) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, nil } @@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 waitAll := flags&linux.MSG_WAITALL != 0 + isPacket := s.isPacket() // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msgFlags |= linux.MSG_CTRUNC } - if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - if s.isPacket && n < int64(r.MsgSize) { + if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } @@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - if s.isPacket && n < int64(r.MsgSize) { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) @@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 { return s.ep.State() } +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + // Unix domain sockets always have a protocol of 0. + return linux.AF_UNIX, s.stype, 0 +} + // provider is a unix domain socket provider. type provider struct{} @@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs // Create the endpoint and socket. var ep transport.Endpoint - var isPacket bool switch stype { case linux.SOCK_DGRAM: - isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_SEQPACKET: - isPacket = true - fallthrough - case linux.SOCK_STREAM: + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep, isPacket), nil + return New(t, ep, stype), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F return nil, nil, syserr.ErrProtocolNotSupported } - var isPacket bool switch stype { - case linux.SOCK_STREAM: - case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: - isPacket = true + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + // Ok default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1, isPacket) - s2 := New(t, ep2, isPacket) + s1 := New(t, ep1, stype) + s2 := New(t, ep2, stype) return s1, s2, nil } -- cgit v1.2.3 From fc746efa9ad57a5001a6328c52622adafa1d3ffe Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Tue, 11 Jun 2019 14:52:06 -0700 Subject: Add support to mount pod shared tmpfs mounts Parse annotations containing 'gvisor.dev/spec/mount' that gives hints about how mounts are shared between containers inside a pod. This information can be used to better inform how to mount these volumes inside gVisor. For example, a volume that is shared between containers inside a pod can be bind mounted inside the sandbox, instead of being two independent mounts. For now, this information is used to allow the same tmpfs mounts to be shared between containers which wasn't possible before. PiperOrigin-RevId: 252704037 --- pkg/sentry/fs/tmpfs/fs.go | 25 ++- runsc/boot/BUILD | 1 + runsc/boot/controller.go | 2 +- runsc/boot/fs.go | 281 ++++++++++++++++++++++++++++-- runsc/boot/fs_test.go | 193 +++++++++++++++++++++ runsc/boot/loader.go | 14 +- runsc/boot/loader_test.go | 4 +- runsc/container/multi_container_test.go | 299 ++++++++++++++++++++++++++++++++ runsc/specutils/fs.go | 40 +++-- 9 files changed, 828 insertions(+), 31 deletions(-) create mode 100644 runsc/boot/fs_test.go (limited to 'pkg/sentry/fs') diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index b7c29a4d1..83e1bf247 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -34,6 +34,16 @@ const ( // GID for the root directory. rootGIDKey = "gid" + // cacheKey sets the caching policy for the mount. + cacheKey = "cache" + + // cacheAll uses the virtual file system cache for everything (default). + cacheAll = "cache" + + // cacheRevalidate allows dirents to be cached, but revalidates them on each + // lookup. + cacheRevalidate = "revalidate" + // TODO(edahlgren/mpratt): support a tmpfs size limit. // size = "size" @@ -122,15 +132,24 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou delete(options, rootGIDKey) } + // Construct a mount which will follow the cache options provided. + var msrc *fs.MountSource + switch options[cacheKey] { + case "", cacheAll: + msrc = fs.NewCachingMountSource(f, flags) + case cacheRevalidate: + msrc = fs.NewRevalidatingMountSource(f, flags) + default: + return nil, fmt.Errorf("invalid cache policy option %q", options[cacheKey]) + } + delete(options, cacheKey) + // Fail if the caller passed us more options than we can parse. They may be // expecting us to set something we can't set. if len(options) > 0 { return nil, fmt.Errorf("unsupported mount options: %v", options) } - // Construct a mount which will cache dirents. - msrc := fs.NewCachingMountSource(f, flags) - // Construct the tmpfs root. return NewDir(ctx, nil, owner, perms, msrc), nil } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index ac28c4339..6ba196917 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -94,6 +94,7 @@ go_test( size = "small", srcs = [ "compat_test.go", + "fs_test.go", "loader_test.go", ], embed = [":boot"], diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index a277145b1..416e5355d 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -340,7 +340,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { cm.l.k = k // Set up the restore environment. - mntr := newContainerMounter(cm.l.spec, "", cm.l.goferFDs, cm.l.k) + mntr := newContainerMounter(cm.l.spec, "", cm.l.goferFDs, cm.l.k, cm.l.mountHints) renv, err := mntr.createRestoreEnvironment(cm.l.conf) if err != nil { return fmt.Errorf("creating RestoreEnvironment: %v", err) diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 939f2419c..2fa0725d1 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -18,6 +18,7 @@ import ( "fmt" "path" "path/filepath" + "sort" "strconv" "strings" "syscall" @@ -50,6 +51,9 @@ const ( // Device name for root mount. rootDevice = "9pfs-/" + // MountPrefix is the annotation prefix for mount hints. + MountPrefix = "gvisor.dev/spec/mount" + // ChildContainersDir is the directory where child container root // filesystems are mounted. ChildContainersDir = "/__runsc_containers__" @@ -292,6 +296,174 @@ func (f *fdDispenser) empty() bool { return len(f.fds) == 0 } +type shareType int + +const ( + invalid shareType = iota + + // container shareType indicates that the mount is used by a single container. + container + + // pod shareType indicates that the mount is used by more than one container + // inside the pod. + pod + + // shared shareType indicates that the mount can also be shared with a process + // outside the pod, e.g. NFS. + shared +) + +func parseShare(val string) (shareType, error) { + switch val { + case "container": + return container, nil + case "pod": + return pod, nil + case "shared": + return shared, nil + default: + return 0, fmt.Errorf("invalid share value %q", val) + } +} + +func (s shareType) String() string { + switch s { + case invalid: + return "invalid" + case container: + return "container" + case pod: + return "pod" + case shared: + return "shared" + default: + return fmt.Sprintf("invalid share value %d", s) + } +} + +// mountHint represents extra information about mounts that are provided via +// annotations. They can override mount type, and provide sharing information +// so that mounts can be correctly shared inside the pod. +type mountHint struct { + name string + share shareType + mount specs.Mount + + // root is the inode where the volume is mounted. For mounts with 'pod' share + // the volume is mounted once and then bind mounted inside the containers. + root *fs.Inode +} + +func (m *mountHint) setField(key, val string) error { + switch key { + case "source": + if len(val) == 0 { + return fmt.Errorf("source cannot be empty") + } + m.mount.Source = val + case "type": + return m.setType(val) + case "share": + share, err := parseShare(val) + if err != nil { + return err + } + m.share = share + case "options": + return m.setOptions(val) + default: + return fmt.Errorf("invalid mount annotation: %s=%s", key, val) + } + return nil +} + +func (m *mountHint) setType(val string) error { + switch val { + case "tmpfs", "bind": + m.mount.Type = val + default: + return fmt.Errorf("invalid type %q", val) + } + return nil +} + +func (m *mountHint) setOptions(val string) error { + opts := strings.Split(val, ",") + if err := specutils.ValidateMountOptions(opts); err != nil { + return err + } + // Sort options so it can be compared with container mount options later on. + sort.Strings(opts) + m.mount.Options = opts + return nil +} + +func (m *mountHint) isSupported() bool { + return m.mount.Type == tmpfs && m.share == pod +} + +// podMountHints contains a collection of mountHints for the pod. +type podMountHints struct { + mounts map[string]*mountHint +} + +func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { + mnts := make(map[string]*mountHint) + for k, v := range spec.Annotations { + // Look for 'gvisor.dev/spec/mount' annotations and parse them. + if strings.HasPrefix(k, MountPrefix) { + parts := strings.Split(k, "/") + if len(parts) != 5 { + return nil, fmt.Errorf("invalid mount annotation: %s=%s", k, v) + } + name := parts[3] + if len(name) == 0 || path.Clean(name) != name { + return nil, fmt.Errorf("invalid mount name: %s", name) + } + mnt := mnts[name] + if mnt == nil { + mnt = &mountHint{name: name} + mnts[name] = mnt + } + if err := mnt.setField(parts[4], v); err != nil { + return nil, err + } + } + } + + // Validate all hints after done parsing. + for name, m := range mnts { + log.Infof("Mount annotation found, name: %s, source: %q, type: %s, share: %v", name, m.mount.Source, m.mount.Type, m.share) + if m.share == invalid { + return nil, fmt.Errorf("share field for %q has not been set", m.name) + } + if len(m.mount.Source) == 0 { + return nil, fmt.Errorf("source field for %q has not been set", m.name) + } + if len(m.mount.Type) == 0 { + return nil, fmt.Errorf("type field for %q has not been set", m.name) + } + + // Check for duplicate mount sources. + for name2, m2 := range mnts { + if name != name2 && m.mount.Source == m2.mount.Source { + return nil, fmt.Errorf("mounts %q and %q have the same mount source %q", m.name, m2.name, m.mount.Source) + } + } + } + + return &podMountHints{mounts: mnts}, nil +} + +func (p *podMountHints) findMount(mount specs.Mount) *mountHint { + for _, m := range p.mounts { + if m.mount.Source == mount.Source { + return m + } + } + return nil +} + type containerMounter struct { // cid is the container ID. May be set to empty for the root container. cid string @@ -306,15 +478,18 @@ type containerMounter struct { fds fdDispenser k *kernel.Kernel + + hints *podMountHints } -func newContainerMounter(spec *specs.Spec, cid string, goferFDs []int, k *kernel.Kernel) *containerMounter { +func newContainerMounter(spec *specs.Spec, cid string, goferFDs []int, k *kernel.Kernel, hints *podMountHints) *containerMounter { return &containerMounter{ cid: cid, root: spec.Root, mounts: compileMounts(spec), fds: fdDispenser{fds: goferFDs}, k: k, + hints: hints, } } @@ -476,6 +651,15 @@ func destroyContainerFS(ctx context.Context, cid string, k *kernel.Kernel) error // 'setMountNS' is called after namespace is created. It must set the mount NS // to 'rootCtx'. func (c *containerMounter) setupRootContainer(userCtx context.Context, rootCtx context.Context, conf *Config, setMountNS func(*fs.MountNamespace)) error { + for _, hint := range c.hints.mounts { + log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type) + inode, err := c.mountSharedMaster(rootCtx, conf, hint) + if err != nil { + return fmt.Errorf("mounting shared master %q: %v", hint.name, err) + } + hint.root = inode + } + // Create a tmpfs mount where we create and mount a root filesystem for // each child container. c.mounts = append(c.mounts, specs.Mount{ @@ -498,21 +682,57 @@ func (c *containerMounter) setupRootContainer(userCtx context.Context, rootCtx c return c.mountSubmounts(rootCtx, conf, mns, root) } +// mountSharedMaster mounts the master of a volume that is shared among +// containers in a pod. It returns the root mount's inode. +func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, hint *mountHint) (*fs.Inode, error) { + // Map mount type to filesystem name, and parse out the options that we are + // capable of dealing with. + fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount) + if err != nil { + return nil, err + } + if len(fsName) == 0 { + return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type) + } + + // Mount with revalidate because it's shared among containers. + opts = append(opts, "cache=revalidate") + + // All filesystem names should have been mapped to something we know. + filesystem := mustFindFilesystem(fsName) + + mf := mountFlags(hint.mount.Options) + if useOverlay { + // All writes go to upper, be paranoid and make lower readonly. + mf.ReadOnly = true + } + + inode, err := filesystem.Mount(ctx, mountDevice(hint.mount), mf, strings.Join(opts, ","), nil) + if err != nil { + return nil, fmt.Errorf("creating mount %q: %v", hint.name, err) + } + + if useOverlay { + log.Debugf("Adding overlay on top of shared mount %q", hint.name) + inode, err = addOverlay(ctx, conf, inode, hint.mount.Type, mf) + if err != nil { + return nil, err + } + } + + return inode, nil +} + // createRootMount creates the root filesystem. func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*fs.Inode, error) { // First construct the filesystem from the spec.Root. mf := fs.MountSourceFlags{ReadOnly: c.root.Readonly || conf.Overlay} - var ( - rootInode *fs.Inode - err error - ) - fd := c.fds.remove() log.Infof("Mounting root over 9P, ioFD: %d", fd) p9FS := mustFindFilesystem("9p") opts := p9MountOptions(fd, conf.FileAccess) - rootInode, err = p9FS.Mount(ctx, rootDevice, mf, strings.Join(opts, ","), nil) + rootInode, err := p9FS.Mount(ctx, rootDevice, mf, strings.Join(opts, ","), nil) if err != nil { return nil, fmt.Errorf("creating root mount point: %v", err) } @@ -579,8 +799,14 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error { for _, m := range c.mounts { - if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil { - return fmt.Errorf("mount submount %q: %v", m.Destination, err) + if hint := c.hints.findMount(m); hint != nil && hint.isSupported() { + if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil { + return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err) + } + } else { + if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil { + return fmt.Errorf("mount submount %q: %v", m.Destination, err) + } } } @@ -653,6 +879,37 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns return nil } +// mountSharedSubmount binds mount to a previously mounted volume that is shared +// among containers in the same pod. +func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error { + // For now enforce that all options are the same. Once bind mount is properly + // supported, then we should ensure the master is less restrictive than the + // container, e.g. master can be 'rw' while container mounts as 'ro'. + if len(mount.Options) != len(source.mount.Options) { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options) + } + sort.Strings(mount.Options) + for i, opt := range mount.Options { + if opt != source.mount.Options[i] { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options) + } + } + + maxTraversals := uint(0) + target, err := mns.FindInode(ctx, root, root, mount.Destination, &maxTraversals) + if err != nil { + return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err) + } + defer target.DecRef() + + if err := mns.Mount(ctx, target, source.root); err != nil { + return fmt.Errorf("bind mount %q error: %v", mount.Destination, err) + } + + log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name) + return nil +} + // addRestoreMount adds a mount to the MountSources map used for restoring a // checkpointed container. func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnvironment, m specs.Mount) error { @@ -678,8 +935,8 @@ func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnviron return nil } -// createRestoreEnvironment builds a fs.RestoreEnvironment called renv by adding the mounts -// to the environment. +// createRestoreEnvironment builds a fs.RestoreEnvironment called renv by adding +// the mounts to the environment. func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEnvironment, error) { renv := &fs.RestoreEnvironment{ MountSources: make(map[string][]fs.MountArgs), @@ -730,7 +987,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn // Technically we don't have to mount tmpfs at /tmp, as we could just rely on // the host /tmp, but this is a nice optimization, and fixes some apps that call // mknod in /tmp. It's unsafe to mount tmpfs if: -// 1. /tmp is mounted explictly: we should not override user's wish +// 1. /tmp is mounted explicitly: we should not override user's wish // 2. /tmp is not empty: mounting tmpfs would hide existing files in /tmp // // Note that when there are submounts inside of '/tmp', directories for the diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go new file mode 100644 index 000000000..49ab34b33 --- /dev/null +++ b/runsc/boot/fs_test.go @@ -0,0 +1,193 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package boot + +import ( + "path" + "reflect" + "strings" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +func TestPodMountHintsHappy(t *testing.T) { + spec := &specs.Spec{ + Annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "pod", + + path.Join(MountPrefix, "mount2", "source"): "bar", + path.Join(MountPrefix, "mount2", "type"): "bind", + path.Join(MountPrefix, "mount2", "share"): "container", + path.Join(MountPrefix, "mount2", "options"): "rw,private", + }, + } + podHints, err := newPodMountHints(spec) + if err != nil { + t.Errorf("newPodMountHints failed: %v", err) + } + + // Check that fields were set correctly. + mount1 := podHints.mounts["mount1"] + if want := "mount1"; want != mount1.name { + t.Errorf("mount1 name, want: %q, got: %q", want, mount1.name) + } + if want := "foo"; want != mount1.mount.Source { + t.Errorf("mount1 source, want: %q, got: %q", want, mount1.mount.Source) + } + if want := "tmpfs"; want != mount1.mount.Type { + t.Errorf("mount1 type, want: %q, got: %q", want, mount1.mount.Type) + } + if want := pod; want != mount1.share { + t.Errorf("mount1 type, want: %q, got: %q", want, mount1.share) + } + if want := []string(nil); !reflect.DeepEqual(want, mount1.mount.Options) { + t.Errorf("mount1 type, want: %q, got: %q", want, mount1.mount.Options) + } + + mount2 := podHints.mounts["mount2"] + if want := "mount2"; want != mount2.name { + t.Errorf("mount2 name, want: %q, got: %q", want, mount2.name) + } + if want := "bar"; want != mount2.mount.Source { + t.Errorf("mount2 source, want: %q, got: %q", want, mount2.mount.Source) + } + if want := "bind"; want != mount2.mount.Type { + t.Errorf("mount2 type, want: %q, got: %q", want, mount2.mount.Type) + } + if want := container; want != mount2.share { + t.Errorf("mount2 type, want: %q, got: %q", want, mount2.share) + } + if want := []string{"private", "rw"}; !reflect.DeepEqual(want, mount2.mount.Options) { + t.Errorf("mount2 type, want: %q, got: %q", want, mount2.mount.Options) + } +} + +func TestPodMountHintsErrors(t *testing.T) { + for _, tst := range []struct { + name string + annotations map[string]string + error string + }{ + { + name: "too short", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1"): "foo", + }, + error: "invalid mount annotation", + }, + { + name: "no name", + annotations: map[string]string{ + MountPrefix + "//source": "foo", + }, + error: "invalid mount name", + }, + { + name: "missing source", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "pod", + }, + error: "source field", + }, + { + name: "missing type", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "share"): "pod", + }, + error: "type field", + }, + { + name: "missing share", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + }, + error: "share field", + }, + { + name: "invalid field name", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "invalid"): "foo", + }, + error: "invalid mount annotation", + }, + { + name: "invalid source", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "pod", + }, + error: "source cannot be empty", + }, + { + name: "invalid type", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "invalid-type", + path.Join(MountPrefix, "mount1", "share"): "pod", + }, + error: "invalid type", + }, + { + name: "invalid share", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "invalid-share", + }, + error: "invalid share", + }, + { + name: "invalid options", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "pod", + path.Join(MountPrefix, "mount1", "options"): "invalid-option", + }, + error: "unknown mount option", + }, + { + name: "duplicate source", + annotations: map[string]string{ + path.Join(MountPrefix, "mount1", "source"): "foo", + path.Join(MountPrefix, "mount1", "type"): "tmpfs", + path.Join(MountPrefix, "mount1", "share"): "pod", + + path.Join(MountPrefix, "mount2", "source"): "foo", + path.Join(MountPrefix, "mount2", "type"): "bind", + path.Join(MountPrefix, "mount2", "share"): "container", + }, + error: "have the same mount source", + }, + } { + t.Run(tst.name, func(t *testing.T) { + spec := &specs.Spec{Annotations: tst.annotations} + podHints, err := newPodMountHints(spec) + if err == nil || !strings.Contains(err.Error(), tst.error) { + t.Errorf("newPodMountHints invalid error, want: .*%s.*, got: %v", tst.error, err) + } + if podHints != nil { + t.Errorf("newPodMountHints must return nil on failure: %+v", podHints) + } + }) + } +} diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 42bddb2e8..3e6095fdc 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -117,6 +117,10 @@ type Loader struct { // // processes is guardded by mu. processes map[execID]*execProcess + + // mountHints provides extra information about mounts for containers that + // apply to the entire pod. + mountHints *podMountHints } // execID uniquely identifies a sentry process that is executed in a container. @@ -299,6 +303,11 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("initializing compat logs: %v", err) } + mountHints, err := newPodMountHints(args.Spec) + if err != nil { + return nil, fmt.Errorf("creating pod mount hints: %v", err) + } + eid := execID{cid: args.ID} l := &Loader{ k: k, @@ -311,6 +320,7 @@ func New(args Args) (*Loader, error) { rootProcArgs: procArgs, sandboxID: args.ID, processes: map[execID]*execProcess{eid: {}}, + mountHints: mountHints, } // We don't care about child signals; some platforms can generate a @@ -502,7 +512,7 @@ func (l *Loader) run() error { // cid for root container can be empty. Only subcontainers need it to set // the mount location. - mntr := newContainerMounter(l.spec, "", l.goferFDs, l.k) + mntr := newContainerMounter(l.spec, "", l.goferFDs, l.k, l.mountHints) if err := mntr.setupFS(ctx, l.conf, &l.rootProcArgs, l.rootProcArgs.Credentials); err != nil { return err } @@ -623,7 +633,7 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file goferFDs = append(goferFDs, fd) } - mntr := newContainerMounter(spec, cid, goferFDs, l.k) + mntr := newContainerMounter(spec, cid, goferFDs, l.k, l.mountHints) if err := mntr.setupFS(ctx, conf, &procArgs, creds); err != nil { return fmt.Errorf("configuring container FS: %v", err) } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 6393cb3fb..2f2499811 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -404,7 +404,7 @@ func TestCreateMountNamespace(t *testing.T) { mns = m ctx.(*contexttest.TestContext).RegisterValue(fs.CtxRoot, mns.Root()) } - mntr := newContainerMounter(&tc.spec, "", []int{sandEnd}, nil) + mntr := newContainerMounter(&tc.spec, "", []int{sandEnd}, nil, &podMountHints{}) if err := mntr.setupRootContainer(ctx, ctx, conf, setMountNS); err != nil { t.Fatalf("createMountNamespace test case %q failed: %v", tc.name, err) } @@ -610,7 +610,7 @@ func TestRestoreEnvironment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conf := testConfig() - mntr := newContainerMounter(tc.spec, "", tc.ioFDs, nil) + mntr := newContainerMounter(tc.spec, "", tc.ioFDs, nil, &podMountHints{}) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 4ea3c74ac..d57a73d46 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -99,6 +99,36 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C return containers, cleanup, nil } +type execDesc struct { + c *Container + cmd []string + want int + desc string +} + +func execMany(execs []execDesc) error { + for _, exec := range execs { + args := &control.ExecArgs{Argv: exec.cmd} + if ws, err := exec.c.executeSync(args); err != nil { + return fmt.Errorf("error executing %+v: %v", args, err) + } else if ws.ExitStatus() != exec.want { + return fmt.Errorf("%q: exec %q got exit status: %d, want: %d", exec.desc, exec.cmd, ws.ExitStatus(), exec.want) + } + } + return nil +} + +func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { + for _, spec := range pod { + spec.Annotations[path.Join(boot.MountPrefix, name, "source")] = mount.Source + spec.Annotations[path.Join(boot.MountPrefix, name, "type")] = mount.Type + spec.Annotations[path.Join(boot.MountPrefix, name, "share")] = "pod" + if len(mount.Options) > 0 { + spec.Annotations[path.Join(boot.MountPrefix, name, "options")] = strings.Join(mount.Options, ",") + } + } +} + // TestMultiContainerSanity checks that it is possible to run 2 dead-simple // containers in the same sandbox. func TestMultiContainerSanity(t *testing.T) { @@ -828,3 +858,272 @@ func TestMultiContainerGoferStop(t *testing.T) { } } } + +// Test that pod shared mounts are properly mounted in 2 containers and that +// changes from one container is reflected in the other. +func TestMultiContainerSharedMount(t *testing.T) { + for _, conf := range configs(all...) { + t.Logf("Running test with conf: %+v", conf) + + // Setup the containers. + sleep := []string{"sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: nil, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + + createSharedMount(mnt0, "test-mount", podSpec...) + + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + file0 := path.Join(mnt0.Destination, "abc") + file1 := path.Join(mnt1.Destination, "abc") + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, + desc: "directory is mounted in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, + desc: "directory is mounted in container1", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/touch", file0}, + desc: "create file in container0", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-f", file0}, + desc: "file appears in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-f", file1}, + desc: "file appears in container1", + }, + { + c: containers[1], + cmd: []string{"/bin/rm", file1}, + desc: "file removed from container1", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "!", "-f", file0}, + desc: "file removed from container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "!", "-f", file1}, + desc: "file removed from container1", + }, + { + c: containers[1], + cmd: []string{"/bin/mkdir", file1}, + desc: "create directory in container1", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", file0}, + desc: "dir appears in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", file1}, + desc: "dir appears in container1", + }, + { + c: containers[0], + cmd: []string{"/bin/rmdir", file0}, + desc: "create directory in container0", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "!", "-d", file0}, + desc: "dir removed from container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "!", "-d", file1}, + desc: "dir removed from container1", + }, + } + if err := execMany(execs); err != nil { + t.Fatal(err.Error()) + } + } +} + +// Test that pod mounts are mounted as readonly when requested. +func TestMultiContainerSharedMountReadonly(t *testing.T) { + for _, conf := range configs(all...) { + t.Logf("Running test with conf: %+v", conf) + + // Setup the containers. + sleep := []string{"sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: []string{"ro"}, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + + createSharedMount(mnt0, "test-mount", podSpec...) + + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + file0 := path.Join(mnt0.Destination, "abc") + file1 := path.Join(mnt1.Destination, "abc") + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, + desc: "directory is mounted in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, + desc: "directory is mounted in container1", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/touch", file0}, + want: 1, + desc: "fails to write to container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/touch", file1}, + want: 1, + desc: "fails to write to container1", + }, + } + if err := execMany(execs); err != nil { + t.Fatal(err.Error()) + } + } +} + +// Test that shared pod mounts continue to work after container is restarted. +func TestMultiContainerSharedMountRestart(t *testing.T) { + for _, conf := range configs(all...) { + t.Logf("Running test with conf: %+v", conf) + + // Setup the containers. + sleep := []string{"sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: nil, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + + createSharedMount(mnt0, "test-mount", podSpec...) + + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + file0 := path.Join(mnt0.Destination, "abc") + file1 := path.Join(mnt1.Destination, "abc") + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/touch", file0}, + desc: "create file in container0", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-f", file0}, + desc: "file appears in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-f", file1}, + desc: "file appears in container1", + }, + } + if err := execMany(execs); err != nil { + t.Fatal(err.Error()) + } + + containers[1].Destroy() + + bundleDir, err := testutil.SetupBundleDir(podSpec[1]) + if err != nil { + t.Fatalf("error restarting container: %v", err) + } + defer os.RemoveAll(bundleDir) + + containers[1], err = Create(ids[1], podSpec[1], conf, bundleDir, "", "", "") + if err != nil { + t.Fatalf("error creating container: %v", err) + } + if err := containers[1].Start(conf); err != nil { + t.Fatalf("error starting container: %v", err) + } + + execs = []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-f", file0}, + desc: "file is still in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-f", file1}, + desc: "file is still in container1", + }, + { + c: containers[1], + cmd: []string{"/bin/rm", file1}, + desc: "file removed from container1", + }, + { + c: containers[0], + cmd: []string{"/usr/bin/test", "!", "-f", file0}, + desc: "file removed from container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "!", "-f", file1}, + desc: "file removed from container1", + }, + } + if err := execMany(execs); err != nil { + t.Fatal(err.Error()) + } + } +} diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go index 1f3afb4e4..6e6902e9f 100644 --- a/runsc/specutils/fs.go +++ b/runsc/specutils/fs.go @@ -16,6 +16,7 @@ package specutils import ( "fmt" + "math/bits" "path" "syscall" @@ -105,22 +106,30 @@ func optionsToFlags(opts []string, source map[string]mapping) uint32 { return rv } -// ValidateMount validates that spec mounts are correct. +// validateMount validates that spec mounts are correct. func validateMount(mnt *specs.Mount) error { if !path.IsAbs(mnt.Destination) { return fmt.Errorf("Mount.Destination must be an absolute path: %v", mnt) } - if mnt.Type == "bind" { - for _, o := range mnt.Options { - if ContainsStr(invalidOptions, o) { - return fmt.Errorf("mount option %q is not supported: %v", o, mnt) - } - _, ok1 := optionsMap[o] - _, ok2 := propOptionsMap[o] - if !ok1 && !ok2 { - return fmt.Errorf("unknown mount option %q", o) - } + return ValidateMountOptions(mnt.Options) + } + return nil +} + +// ValidateMountOptions validates that mount options are correct. +func ValidateMountOptions(opts []string) error { + for _, o := range opts { + if ContainsStr(invalidOptions, o) { + return fmt.Errorf("mount option %q is not supported", o) + } + _, ok1 := optionsMap[o] + _, ok2 := propOptionsMap[o] + if !ok1 && !ok2 { + return fmt.Errorf("unknown mount option %q", o) + } + if err := validatePropagation(o); err != nil { + return err } } return nil @@ -133,5 +142,14 @@ func validateRootfsPropagation(opt string) error { if flags&(syscall.MS_SLAVE|syscall.MS_PRIVATE) == 0 { return fmt.Errorf("root mount propagation option must specify private or slave: %q", opt) } + return validatePropagation(opt) +} + +func validatePropagation(opt string) error { + flags := PropOptionsToFlags([]string{opt}) + exclusive := flags & (syscall.MS_SLAVE | syscall.MS_PRIVATE | syscall.MS_SHARED | syscall.MS_UNBINDABLE) + if bits.OnesCount32(exclusive) > 1 { + return fmt.Errorf("mount propagation options are mutually exclusive: %q", opt) + } return nil } -- cgit v1.2.3