From d0f8b3174e01cc14939d8f631e4415cf30925208 Mon Sep 17 00:00:00 2001 From: Howard Zhang Date: Tue, 3 Nov 2020 13:54:40 +0800 Subject: ARM64: follow nogo rules add function description Signed-off-by: Howard Zhang --- pkg/sentry/platform/ring0/kernel_arm64.go | 6 ++++++ pkg/sentry/platform/ring0/lib_arm64.go | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index b294ccc7c..68291b504 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -51,6 +51,12 @@ func IsCanonical(addr uint64) bool { return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 } +// SwitchToUser performs an eret. +// +// The return value is the exception vector. +// +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index d91a09de1..456107cd8 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -28,13 +28,13 @@ func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) -// FPCR returns the value of FPCR register. +// GetFPCR returns the value of FPCR register. func GetFPCR() (value uintptr) // SetFPCR writes the FPCR value. func SetFPCR(value uintptr) -// FPSR returns the value of FPSR register. +// GetFPSR returns the value of FPSR register. func GetFPSR() (value uintptr) // SetFPSR writes the FPSR value. -- cgit v1.2.3 From 4f79706ccdc8b6515ad384c5f1896f5405e9d445 Mon Sep 17 00:00:00 2001 From: Robin Luk Date: Thu, 19 Nov 2020 17:58:24 +0800 Subject: arm64 tlb: add support for tlbi-vale1ls/tlbi-aside1ls This patch adds support for tlbi-vale1ls/tlbi-aside1ls. And make the code consistent with the flush strategy of the x86 platform. Signed-off-by: Robin Luk --- pkg/sentry/platform/ring0/kernel_arm64.go | 2 +- pkg/sentry/platform/ring0/lib_arm64.go | 8 +++++++- pkg/sentry/platform/ring0/lib_arm64.s | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 6cbbf001f..c1c808b96 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,7 +53,7 @@ func IsCanonical(addr uint64) bool { func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) if switchOpts.Flush { - FlushTlbAll() + FlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index d91a09de1..bf1c655f4 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -22,7 +22,13 @@ func storeAppASID(asid uintptr) // LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. func LocalFlushTlbAll() -// FlushTlbAll flush all tlb. +// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable. +func FlushTlbByVA(addr uintptr) + +// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. +func FlushTlbByASID(asid uintptr) + +// FlushTlbAll invalidates all tlb. func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. 
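For context: the assembly added to lib_arm64.s in the next file diff builds the TLBI operand by shifting the ASID left by TLBI_ASID_SHIFT (48) before issuing "tlbi aside1is". A standalone Go sketch of that encoding, included here purely as illustration and not part of the patch:

// tlbiASIDOperand returns the operand used for "tlbi aside1is": the ASID
// occupies bits [63:48] of the register, and the low bits are unused for
// this form of the instruction.
func tlbiASIDOperand(asid uint16) uint64 {
	const tlbiASIDShift = 48 // mirrors TLBI_ASID_SHIFT in lib_arm64.s
	return uint64(asid) << tlbiASIDShift
}
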
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 19c1fca8b..675a8bdb7 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,23 @@ #include "funcdata.h" #include "textflag.h" +#define TLBI_ASID_SHIFT 48 + +TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + DSB $10 // dsb(ishst) + WORD $0xd50883a1 // tlbi vale1is, x1 + DSB $11 // dsb(ish) + RET + +TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088341 // tlbi aside1is, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) -- cgit v1.2.3 From e60514493892a17a8ce7b3d98747164926600614 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 15 Dec 2020 12:38:31 -0800 Subject: Internal change. PiperOrigin-RevId: 347671070 --- pkg/sentry/fsimpl/verity/verity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 9571ce9f1..9563ceab4 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -748,7 +748,7 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) // file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The // hash of the generated Merkle tree and the data size is returned. If fd // points to a regular file, the data is the content of the file. If fd points -// to a directory, the data is all hahes of its children, written to the Merkle +// to a directory, the data is all hashes of its children, written to the Merkle // tree file. // // Preconditions: fd.d.fs.verityMu must be locked. -- cgit v1.2.3 From cc28d36845cd3b2267ececbdf81b2c265267cdec Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Tue, 15 Dec 2020 13:46:38 -0800 Subject: [netstack] Make recvmsg(2) call to host in hostinet even if dst is empty. We want to make the recvmsg syscall to the host regardless of if the dst is empty or not so that: - Host can populate the control messages if necessary. - Host can return sender address. - Host can return appropriate errors. Earlier because we were using the IOSequence.CopyOutFrom() API, the usermem package does not even call the Reader function if the destination is empty (as an optimization). PiperOrigin-RevId: 347684566 --- pkg/sentry/socket/hostinet/socket.go | 110 ++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 48 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index be418df2e..1f220c343 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). 
+ sysflags := flags | syscall.MSG_DONTWAIT + + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } + var senderAddrBuf []byte + if senderRequested { + senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) + } + var controlBuf []byte + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} + // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. @@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr linux.SockAddr var senderAddrBuf []byte - if senderRequested { - senderAddrBuf = make([]byte, sizeofSockaddr) - } - var controlBuf []byte var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil - } - - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT - - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. 
+ if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil + } + if dsts.IsEmpty() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) + n, err := copyToDst() if flags&syscall.MSG_DONTWAIT == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which @@ -494,22 +505,26 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { @@ -558,8 +573,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. -- cgit v1.2.3 From f6407de6bafbf8fe3e4579c876640672380fa96c Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Tue, 15 Dec 2020 15:25:35 -0800 Subject: [syzkaller] Avoid AIOContext from resurrecting after being marked dead. syzkaller reported the closing of a nil channel. This is only possible when the AIOContext was destroyed twice. Some scenarios that could lead to this: - It died and then some called aioCtx.Prepare() on it and then killed it again which could cause the double destroy. The context could have been destroyed in between the call to LookupAIOContext() and Prepare(). - aioManager was destroyed but it did not update the contexts map. So Lookup could still return a dead AIOContext and then someone could call Prepare on it and kill it again. So added a check in aioCtx.Prepare() for the context being dead. This will prevent a dead context from resurrecting. Also refactored code to destroy the aioContext consistently. Earlier we were not munmapping the aioContexts that were destroyed upon aioManager destruction. 
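A minimal sketch of the Prepare() guard described above, simplified from the change to pkg/sentry/mm/aio_context.go in the diff below (the struct is trimmed to the fields the guard needs; imports of sync and the syserror package are omitted for brevity):

type aioContext struct {
	mu             sync.Mutex
	dead           bool
	outstanding    uint32
	maxOutstanding uint32
}

// prepare reserves a request slot. Once the context has been marked dead it
// always fails, so a context that was looked up and then destroyed by a
// racing caller cannot be revived by a later submission.
func (ctx *aioContext) prepare() error {
	ctx.mu.Lock()
	defer ctx.mu.Unlock()
	if ctx.dead {
		return syserror.EINVAL // Context died after the caller looked it up.
	}
	if ctx.outstanding >= ctx.maxOutstanding {
		return syserror.EAGAIN // Context is busy.
	}
	ctx.outstanding++
	return nil
}
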
Reported-by: syzbot+ef6a588d0ce6059991d2@syzkaller.appspotmail.com PiperOrigin-RevId: 347704347 --- pkg/sentry/mm/aio_context.go | 79 ++++++++++++++++++++--------------- pkg/sentry/mm/lifecycle.go | 2 +- pkg/sentry/mm/mm_test.go | 43 +++++++++++++++++++ pkg/sentry/syscalls/linux/sys_aio.go | 5 +-- pkg/sentry/syscalls/linux/vfs2/aio.go | 5 +-- 5 files changed, 94 insertions(+), 40 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4c8cd38ed..5ab2ef79f 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -36,12 +36,12 @@ type aioManager struct { contexts map[uint64]*AIOContext } -func (a *aioManager) destroy() { - a.mu.Lock() - defer a.mu.Unlock() +func (mm *MemoryManager) destroyAIOManager(ctx context.Context) { + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() - for _, ctx := range a.contexts { - ctx.destroy() + for id := range mm.aioManager.contexts { + mm.destroyAIOContextLocked(ctx, id) } } @@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { // be drained. // // Nil is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { - a.mu.Lock() - defer a.mu.Unlock() - ctx, ok := a.contexts[id] +// +// Precondition: mm.aioManager.mu is locked. +func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext { + aioCtx, ok := mm.aioManager.contexts[id] if !ok { return nil } - delete(a.contexts, id) - ctx.destroy() - return ctx + + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + + delete(mm.aioManager.contexts, id) + aioCtx.destroy() + return aioCtx } // lookupAIOContext looks up the given context. @@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() { } } -// Prepare reserves space for a new request, returning true if available. -// Returns false if the context is busy. -func (ctx *AIOContext) Prepare() bool { +// Prepare reserves space for a new request, returning nil if available. +// Returns EAGAIN if the context is busy and EINVAL if the context is dead. +func (ctx *AIOContext) Prepare() error { ctx.mu.Lock() defer ctx.mu.Unlock() + if ctx.dead { + // Context died after the caller looked it up. + return syserror.EINVAL + } if ctx.outstanding >= ctx.maxOutstanding { - return false + // Context is busy. + return syserror.EAGAIN } ctx.outstanding++ - return true + return nil } // PopRequest pops a completed request if available, this function does not do @@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // DestroyAIOContext destroys an asynchronous I/O context. It returns the // destroyed context. nil if the context does not exist. func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { - if _, ok := mm.LookupAIOContext(ctx, id); !ok { + if !mm.isValidAddr(ctx, id) { return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. 
- // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) - - return mm.aioManager.destroyAIOContext(id) + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() + return mm.destroyAIOContextLocked(ctx, id) } // LookupAIOContext looks up the given context. It returns false if the context @@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC return nil, false } - // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes - // from id). - var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) - if err != nil { + // Protect against 'id' that is inaccessible. + if !mm.isValidAddr(ctx, id) { return nil, false } return aioCtx, true } + +// isValidAddr determines if the address `id` is valid. (Linux also reads 4 +// bytes from id). +func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { + var buf [4]byte + _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + return err == nil +} diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 09dbc06a4..120707429 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users)) } - mm.aioManager.destroy() + mm.destroyAIOManager(ctx) mm.metadataMu.Lock() exe := mm.executable diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index acac3d357..bc53bd41e 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } } + +// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be +// prepared after destruction. +func TestAIOPrepareAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + defer mm.DecUsers(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + aioCtx, ok := mm.LookupAIOContext(ctx, id) + if !ok { + t.Fatalf("AIOContext not found") + } + mm.DestroyAIOContext(ctx, id) + + // Prepare should fail because aioCtx should be destroyed. + if err := aioCtx.Prepare(); err != syserror.EINVAL { + t.Errorf("aioCtx.Prepare got err %v want nil", err) + } else if err == nil { + aioCtx.CancelPendingRequest() + } +} + +// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be +// looked up after memory manager is destroyed. +func TestAIOLookupAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + mm.DecUsers(ctx) + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + mm.DecUsers(ctx) // This destroys the AIOContext manager. 
+ + if _, ok := mm.LookupAIOContext(ctx, id); ok { + t.Errorf("AIOContext found even after AIOContext manager is destroyed") + } +} diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 0bf313a13..c2285f796 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := ctx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := ctx.Prepare(); err != nil { + return err } if eventFile != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 6d0a38330..1365a5a62 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := aioCtx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := aioCtx.Prepare(); err != nil { + return err } if eventFD != nil { -- cgit v1.2.3 From 7aa674eb68e9b760ea72508dfb79a19dbf5b85ed Mon Sep 17 00:00:00 2001 From: Chong Cai Date: Tue, 15 Dec 2020 15:38:19 -0800 Subject: Change violation mode to an enum PiperOrigin-RevId: 347706953 --- pkg/sentry/fsimpl/verity/filesystem.go | 2 +- pkg/sentry/fsimpl/verity/verity.go | 32 ++++++++++++++++++++------------ pkg/sentry/fsimpl/verity/verity_test.go | 10 +++++----- 3 files changed, 26 insertions(+), 18 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 04e7110a3..a4ad625bb 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -163,7 +163,7 @@ afterSymlink: // verifyChildLocked verifies the hash of child against the already verified // hash of the parent to ensure the child is expected. verifyChild triggers a // sentry panic if unexpected modifications to the file system are detected. In -// noCrashOnVerificationFailure mode it returns a syserror instead. +// ErrorOnViolation mode it returns a syserror instead. // // Preconditions: // * fs.renameMu must be locked. diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 9563ceab4..66029c64d 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -92,10 +92,8 @@ const ( ) var ( - // noCrashOnVerificationFailure indicates whether the sandbox should panic - // whenever verification fails. If true, an error is returned instead of - // panicking. This should only be set for tests. - noCrashOnVerificationFailure bool + // action specifies the action towards detected violation. + action ViolationAction // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. @@ -106,6 +104,18 @@ var ( // content. type HashAlgorithm int +// ViolationAction is a type specifying the action when an integrity violation +// is detected. +type ViolationAction int + +const ( + // PanicOnViolation terminates the sentry on detected violation. + PanicOnViolation ViolationAction = 0 + // ErrorOnViolation returns an error from the violating system call on + // detected violation. + ErrorOnViolation = 1 +) + // Currently supported hashing algorithms include SHA256 and SHA512. 
const ( SHA256 HashAlgorithm = iota @@ -200,10 +210,8 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // NoCrashOnVerificationFailure indicates whether the sandbox should - // panic whenever verification fails. If true, an error is returned - // instead of panicking. This should only be set for tests. - NoCrashOnVerificationFailure bool + // Action specifies the action on an integrity violation. + Action ViolationAction } // Name implements vfs.FilesystemType.Name. @@ -215,10 +223,10 @@ func (FilesystemType) Name() string { func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means -// unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +// unexpected modification to the file system is detected. In ErrorOnViolation +// mode, it returns EIO, otherwise it panic. func alertIntegrityViolation(msg string) error { - if noCrashOnVerificationFailure { + if action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -231,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + action = iopts.Action // Mount the lower file system. The lower file system is wrapped inside // verity, and should not be exposed or connected. diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index bd948715f..30d8b4355 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -92,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - NoCrashOnVerificationFailure: true, + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + Alg: hashAlg, + AllowRuntimeEnable: true, + Action: ErrorOnViolation, }, }, }) -- cgit v1.2.3 From 1e56a2f9a29ff72eada493bf024a4e3fd5a963b6 Mon Sep 17 00:00:00 2001 From: Jing Chen Date: Tue, 15 Dec 2020 16:03:41 -0800 Subject: Implement command SEM_INFO and SEM_STAT for semctl. 
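In short: SEM_STAT behaves like IPC_STAT except that its first argument is an index into the kernel's internal semaphore array rather than a set ID, and the set's ID is returned on success; SEM_INFO behaves like IPC_INFO except that the semusz and semaem fields report the number of existing semaphore sets and semaphores. A compressed sketch of the registry-side helper added below (field and method names follow the diff):

// SemInfo reports live usage in the two fields that IPC_INFO fills with
// static limits.
func (r *Registry) SemInfo() *linux.SemInfo {
	r.mu.Lock()
	defer r.mu.Unlock()
	info := r.IPCInfo()
	info.SemUsz = uint32(len(r.semaphores)) // existing semaphore sets
	info.SemAem = uint32(r.totalSems())     // existing semaphores
	return info
}
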
PiperOrigin-RevId: 347711998 --- pkg/abi/linux/sem.go | 2 +- pkg/sentry/kernel/semaphore/semaphore.go | 28 ++++- pkg/sentry/syscalls/linux/linux64.go | 4 +- pkg/sentry/syscalls/linux/sys_sem.go | 35 +++++- test/syscalls/linux/semaphore.cc | 182 ++++++++++++++++++++++++++----- 5 files changed, 214 insertions(+), 37 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 0adff8dff..2424884c1 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -43,10 +43,10 @@ const ( SEMVMX = 32767 SEMAEM = SEMVMX - // followings are unused in kernel SEMUME = SEMOPM SEMMNU = SEMMNS SEMMAP = SEMMNS + SEMUSZ = 20 ) const SEM_UNDO = 0x1000 diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 3dd3953b3..db01e4a97 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -193,12 +193,26 @@ func (r *Registry) IPCInfo() *linux.SemInfo { SemMsl: linux.SEMMSL, SemOpm: linux.SEMOPM, SemUme: linux.SEMUME, - SemUsz: 0, // SemUsz not supported. + SemUsz: linux.SEMUSZ, SemVmx: linux.SEMVMX, SemAem: linux.SEMAEM, } } +// SemInfo returns a seminfo structure containing the same information as +// for IPC_INFO, except that SemUsz field returns the number of existing +// semaphore sets, and SemAem field returns the number of existing semaphores. +func (r *Registry) SemInfo() *linux.SemInfo { + r.mu.Lock() + defer r.mu.Unlock() + + info := r.IPCInfo() + info.SemUsz = uint32(len(r.semaphores)) + info.SemAem = uint32(r.totalSems()) + + return info +} + // HighestIndex returns the index of the highest used entry in // the kernel's array. func (r *Registry) HighestIndex() int32 { @@ -289,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set { return r.semaphores[id] } +// FindByIndex looks up a set given an index. 
+func (r *Registry) FindByIndex(index int32) *Set { + r.mu.Lock() + defer r.mu.Unlock() + + id, present := r.indexes[index] + if !present { + return nil + } + return r.semaphores[id] +} + func (r *Registry) findByKey(key int32) *Set { for _, v := range r.semaphores { if v.key == key { diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index cff442846..b815e498f 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index a62a6b3b5..1166cd7bb 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -155,10 +155,28 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } return uintptr(r.HighestIndex()), nil, nil - case linux.SEM_INFO, - linux.SEM_STAT, - linux.SEM_STAT_ANY: + case linux.SEM_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.SemInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + case linux.SEM_STAT: + arg := args[3].Pointer() + // id is an index in SEM_STAT. 
+ semid, ds, err := semStat(t, id) + if err != nil { + return 0, nil, err + } + if _, err := ds.CopyOut(t, arg); err != nil { + return 0, nil, err + } + return uintptr(semid), nil, err + + case linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -203,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { return set.GetStat(creds) } +func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByIndex(index) + if set == nil { + return 0, nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + ds, err := set.GetStat(creds) + return set.ID, ds, err +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index fb4695e72..c2f080917 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -32,10 +32,23 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" +using ::testing::Contains; + namespace gvisor { namespace testing { namespace { +constexpr int kSemMap = 1024000000; +constexpr int kSemMni = 32000; +constexpr int kSemMns = 1024000000; +constexpr int kSemMnu = 1024000000; +constexpr int kSemMsl = 32000; +constexpr int kSemOpm = 500; +constexpr int kSemUme = 500; +constexpr int kSemUsz = 20; +constexpr int kSemVmx = 32767; +constexpr int kSemAem = 32767; + class AutoSem { public: explicit AutoSem(int id) : id_(id) {} @@ -775,42 +788,151 @@ TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) { } TEST(SemaphoreTest, IpcInfo) { - std::stack sem_ids; - std::stack max_used_indexes; + constexpr int kLoops = 5; + std::set sem_ids; struct seminfo info; - for (int i = 0; i < 3; i++) { - int sem_id = 0; - ASSERT_THAT(sem_id = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT), - SyscallSucceeds()); - sem_ids.push(sem_id); - int max_used_index = 0; - EXPECT_THAT(max_used_index = semctl(0, 0, IPC_INFO, &info), - SyscallSucceeds()); - if (!max_used_indexes.empty()) { - EXPECT_GT(max_used_index, max_used_indexes.top()); + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0)); + for (int i = 0; i < kLoops; i++) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + sem_ids.insert(sem.release()); + } + ASSERT_EQ(sem_ids.size(), kLoops); + + int max_used_index = 0; + EXPECT_THAT(max_used_index = semctl(0, 0, IPC_INFO, &info), + SyscallSucceeds()); + + int index_count = 0; + for (int i = 0; i <= max_used_index; i++) { + struct semid_ds ds = {}; + int sem_id = semctl(i, 0, SEM_STAT, &ds); + // Only if index i is used within the registry. 
+ if (sem_id != -1) { + ASSERT_THAT(sem_ids, Contains(sem_id)); + struct semid_ds ipc_stat_ds; + ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key); + EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid); + EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid); + EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid); + EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid); + EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode); + EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime); + EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime); + EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems); + + // Remove the semaphore set's read permission. + struct semid_ds ipc_set_ds; + ipc_set_ds.sem_perm.uid = getuid(); + ipc_set_ds.sem_perm.gid = getgid(); + // Keep the semaphore set's write permission so that it could be removed. + ipc_set_ds.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds()); + ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES)); + + index_count += 1; } - max_used_indexes.push(max_used_index); } - while (!sem_ids.empty()) { - int sem_id = sem_ids.top(); - sem_ids.pop(); + EXPECT_EQ(index_count, kLoops); + ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), + SyscallSucceedsWithValue(max_used_index)); + for (const int sem_id : sem_ids) { ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds()); - int max_index = max_used_indexes.top(); - EXPECT_THAT(max_index = semctl(0, 0, IPC_INFO, &info), SyscallSucceeds()); - EXPECT_GE(max_used_indexes.top(), max_index); - max_used_indexes.pop(); } + + ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0)); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + EXPECT_EQ(info.semusz, kSemUsz); + EXPECT_EQ(info.semvmx, kSemVmx); + EXPECT_EQ(info.semaem, kSemAem); +} + +TEST(SemaphoreTest, SemInfo) { + constexpr int kLoops = 5; + constexpr int kSemSetSize = 3; + std::set sem_ids; + struct seminfo info; + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceedsWithValue(0)); + for (int i = 0; i < kLoops; i++) { + AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + sem_ids.insert(sem.release()); + } + ASSERT_EQ(sem_ids.size(), kLoops); + int max_used_index = 0; + EXPECT_THAT(max_used_index = semctl(0, 0, SEM_INFO, &info), + SyscallSucceeds()); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + EXPECT_EQ(info.semusz, sem_ids.size()); + EXPECT_EQ(info.semvmx, kSemVmx); + EXPECT_EQ(info.semaem, sem_ids.size() * kSemSetSize); + + int index_count = 0; + for (int i = 0; i <= max_used_index; i++) { + struct semid_ds ds = {}; + int sem_id = semctl(i, 0, SEM_STAT, &ds); + // Only if index i is used within the registry. 
+ if (sem_id != -1) { + ASSERT_THAT(sem_ids, Contains(sem_id)); + struct semid_ds ipc_stat_ds; + ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key); + EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid); + EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid); + EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid); + EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid); + EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode); + EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime); + EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime); + EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems); + + // Remove the semaphore set's read permission. + struct semid_ds ipc_set_ds; + ipc_set_ds.sem_perm.uid = getuid(); + ipc_set_ds.sem_perm.gid = getgid(); + // Keep the semaphore set's write permission so that it could be removed. + ipc_set_ds.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds()); + ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES)); + + index_count += 1; + } + } + EXPECT_EQ(index_count, kLoops); + ASSERT_THAT(semctl(0, 0, SEM_INFO, &info), + SyscallSucceedsWithValue(max_used_index)); + for (const int sem_id : sem_ids) { + ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds()); + } - EXPECT_EQ(info.semmap, 1024000000); - EXPECT_EQ(info.semmni, 32000); - EXPECT_EQ(info.semmns, 1024000000); - EXPECT_EQ(info.semmnu, 1024000000); - EXPECT_EQ(info.semmsl, 32000); - EXPECT_EQ(info.semopm, 500); - EXPECT_EQ(info.semume, 500); - EXPECT_EQ(info.semvmx, 32767); - EXPECT_EQ(info.semaem, 32767); + ASSERT_THAT(semctl(0, 0, SEM_INFO, &info), SyscallSucceedsWithValue(0)); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + EXPECT_EQ(info.semusz, 0); + EXPECT_EQ(info.semvmx, kSemVmx); + EXPECT_EQ(info.semaem, 0); } } // namespace -- cgit v1.2.3 From 97406b20a1551ddc8d1884d11d81b958b829afae Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Tue, 15 Dec 2020 16:49:39 -0800 Subject: Internal change. PiperOrigin-RevId: 347720083 --- pkg/cpuid/cpuid.go | 11 +++++++++++ pkg/cpuid/cpuid_x86.go | 11 ----------- pkg/sentry/arch/arch.go | 15 +++++++++++++++ pkg/sentry/arch/arch_state_x86.go | 17 ----------------- 4 files changed, 26 insertions(+), 28 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index f7f9dbf86..69eeb7528 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -36,3 +36,14 @@ package cpuid // On arm64, features are numbered according to the ELF HWCAP definition. // arch/arm64/include/uapi/asm/hwcap.h type Feature int + +// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a +// subset of the host feature set. +type ErrIncompatible struct { + message string +} + +// Error implements error. +func (e ErrIncompatible) Error() string { + return e.message +} diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index 17a89c00d..392711e8f 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool { return fs.VendorID == intelVendorID } -// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a -// subset of the host feature set. -type ErrIncompatible struct { - message string -} - -// Error implements error. 
-func (e ErrIncompatible) Error() string { - return e.message -} - // CheckHostCompatible returns nil if fs is a subset of the host feature set. func (fs *FeatureSet) CheckHostCompatible() error { hfs := HostFeatureSet() diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index d75d665ae..dd2effdf9 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } + +// ErrFloatingPoint indicates a failed restore due to unusable floating point +// state. +type ErrFloatingPoint struct { + // supported is the supported floating point state. + supported uint64 + + // saved is the saved floating point state. + saved uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrFloatingPoint) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 19ce99d25..840e53d33 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -17,27 +17,10 @@ package arch import ( - "fmt" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" ) -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} - // XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 // and SSE state, so this is the equivalent XSTATE_BV value. const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE -- cgit v1.2.3 From 74788b1b6194ef62f8355f7e4721c00f615d16ad Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Thu, 17 Dec 2020 08:45:38 -0800 Subject: [netstack] Implement MSG_ERRQUEUE flag for recvmsg(2). Introduces the per-socket error queue and the necessary cmsg mechanisms. PiperOrigin-RevId: 348028508 --- pkg/abi/linux/BUILD | 1 + pkg/abi/linux/errqueue.go | 93 ++++++++++++++++++++++++++++++++ pkg/sentry/socket/control/control.go | 39 ++++++++++++++ pkg/sentry/socket/hostinet/socket.go | 19 ++++--- pkg/sentry/socket/netstack/netstack.go | 43 +++++++++++++++ pkg/sentry/socket/socket.go | 55 +++++++++++++++++++ pkg/sentry/syscalls/linux/sys_socket.go | 5 -- pkg/sentry/syscalls/linux/vfs2/socket.go | 5 -- pkg/tcpip/BUILD | 14 +++++ pkg/tcpip/socketops.go | 63 +++++++++++++++++++++- pkg/tcpip/tcpip.go | 3 ++ 11 files changed, 323 insertions(+), 17 deletions(-) create mode 100644 pkg/abi/linux/errqueue.go (limited to 'pkg/sentry') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a0654df2f..8fa61d6f7 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -21,6 +21,7 @@ go_library( "epoll_amd64.go", "epoll_arm64.go", "errors.go", + "errqueue.go", "eventfd.go", "exec.go", "fadvise.go", diff --git a/pkg/abi/linux/errqueue.go b/pkg/abi/linux/errqueue.go new file mode 100644 index 000000000..3905d4222 --- /dev/null +++ b/pkg/abi/linux/errqueue.go @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + +// Socket error origin codes as defined in include/uapi/linux/errqueue.h. +const ( + SO_EE_ORIGIN_NONE = 0 + SO_EE_ORIGIN_LOCAL = 1 + SO_EE_ORIGIN_ICMP = 2 + SO_EE_ORIGIN_ICMP6 = 3 +) + +// SockExtendedErr represents struct sock_extended_err in Linux defined in +// include/uapi/linux/errqueue.h. +// +// +marshal +type SockExtendedErr struct { + Errno uint32 + Origin uint8 + Type uint8 + Code uint8 + Pad uint8 + Info uint32 + Data uint32 +} + +// SockErrCMsg represents the IP*_RECVERR control message. +type SockErrCMsg interface { + marshal.Marshallable + + CMsgLevel() uint32 + CMsgType() uint32 +} + +// SockErrCMsgIPv4 is the IP_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv4 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv4/ip_sockglue.c:ip_recv_error(). +// +// +marshal +type SockErrCMsgIPv4 struct { + SockExtendedErr + Offender SockAddrInet +} + +var _ SockErrCMsg = (*SockErrCMsgIPv4)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv4) CMsgLevel() uint32 { + return SOL_IP +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv4) CMsgType() uint32 { + return IP_RECVERR +} + +// SockErrCMsgIPv6 is the IPV6_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv6 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv6/datagram.c:ipv6_recv_error(). +// +// +marshal +type SockErrCMsgIPv6 struct { + SockExtendedErr + Offender SockAddrInet6 +} + +var _ SockErrCMsg = (*SockErrCMsgIPv6)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv6) CMsgLevel() uint32 { + return SOL_IPV6 +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv6) CMsgType() uint32 { + return IPV6_RECVERR +} diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index b88cdca48..ff6b71802 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -371,6 +371,17 @@ func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, b buf, level, optType, t.Arch().Width(), originalDstAddress) } +// PackSockExtendedErr packs an IP*_RECVERR socket control message. +func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { + return putCmsgStruct( + buf, + sockErr.CMsgLevel(), + sockErr.CMsgType(), + t.Arch().Width(), + sockErr, + ) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. 
@@ -403,6 +414,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } + if cmsgs.IP.SockErr != nil { + buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) + } + return buf } @@ -440,6 +455,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) } + if cmsgs.IP.SockErr != nil { + space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) + } + return space } @@ -546,6 +565,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) + case linux.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } @@ -568,6 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) + case linux.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 1f220c343..2b34ef190 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -450,11 +450,7 @@ func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, sende // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. - // - // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the - // Socket interface's dependence on netstack. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { + if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } @@ -488,7 +484,8 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var ch chan struct{} n, err := copyToDst() - if flags&syscall.MSG_DONTWAIT == 0 { + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. 
@@ -551,6 +548,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s var addr linux.SockAddrInet binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } case linux.SOL_IPV6: @@ -563,6 +565,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s var addr linux.SockAddrInet6 binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } case linux.SOL_TCP: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 23d5cab9c..a8ab6b385 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2772,6 +2772,8 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages { IP: socket.IPControlMessages{ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp, + HasInq: s.readCM.HasInq, + Inq: s.readCM.Inq, HasTOS: s.readCM.HasTOS, TOS: s.readCM.TOS, HasTClass: s.readCM.HasTClass, @@ -2779,6 +2781,7 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages { HasIPPacketInfo: s.readCM.HasIPPacketInfo, PacketInfo: s.readCM.PacketInfo, OriginalDstAddress: s.readCM.OriginalDstAddress, + SockErr: s.readCM.SockErr, }, } } @@ -2795,9 +2798,49 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// addrFamilyFromNetProto returns the address family identifier for the given +// network protocol. +func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { + switch net { + case header.IPv4ProtocolNumber: + return linux.AF_INET + case header.IPv6ProtocolNumber: + return linux.AF_INET6 + default: + panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net)) + } +} + +// recvErr handles MSG_ERRQUEUE for recvmsg(2). +// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). +func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + sockErr := s.Endpoint.SocketOptions().DequeueErr() + if sockErr == nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + } + + // The payload of the original packet that caused the error is passed as + // normal data via msg_iovec. -- recvmsg(2) + msgFlags := linux.MSG_ERRQUEUE + if int(dst.NumBytes()) < len(sockErr.Payload) { + msgFlags |= linux.MSG_TRUNC + } + n, err := dst.CopyOut(t, sockErr.Payload) + + // The original destination address of the datagram that caused the error is + // supplied via msg_name. -- recvmsg(2) + dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst) + cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})} + return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err) +} + // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. 
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { + if flags&linux.MSG_ERRQUEUE != 0 { + return s.recvErr(t, dst) + } + trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index bcc426e33..97729dacc 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -56,6 +56,57 @@ func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPack return p } +// errOriginToLinux maps tcpip socket origin to Linux socket origin constants. +func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 { + switch origin { + case tcpip.SockExtErrorOriginNone: + return linux.SO_EE_ORIGIN_NONE + case tcpip.SockExtErrorOriginLocal: + return linux.SO_EE_ORIGIN_LOCAL + case tcpip.SockExtErrorOriginICMP: + return linux.SO_EE_ORIGIN_ICMP + case tcpip.SockExtErrorOriginICMP6: + return linux.SO_EE_ORIGIN_ICMP6 + default: + panic(fmt.Sprintf("unknown socket origin: %d", origin)) + } +} + +// sockErrCmsgToLinux converts SockError control message from tcpip format to +// Linux format. +func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { + if sockErr == nil { + return nil + } + + ee := linux.SockExtendedErr{ + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Origin: errOriginToLinux(sockErr.ErrOrigin), + Type: sockErr.ErrType, + Code: sockErr.ErrCode, + Info: sockErr.ErrInfo, + } + + switch sockErr.NetProto { + case header.IPv4ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet) + } + return errMsg + case header.IPv6ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet6) + } + return errMsg + default: + panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto)) + } +} + // NewIPControlMessages converts the tcpip ControlMessgaes (which does not // have Linux specific format) to Linux format. func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { @@ -75,6 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa HasIPPacketInfo: cmgs.HasIPPacketInfo, PacketInfo: packetInfoToLinux(cmgs.PacketInfo), OriginalDstAddress: orgDstAddr, + SockErr: sockErrCmsgToLinux(cmgs.SockErr), } } @@ -117,6 +169,9 @@ type IPControlMessages struct { // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress linux.SockAddr + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr linux.SockErrCMsg } // Release releases Unix domain socket credentials and rights. 
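As a usage illustration (hypothetical userspace code, not part of this patch): once a socket has opted in via IP_RECVERR (the sentry-side option is wired up by the following commit in this series), an application drains the per-socket error queue with recvmsg(MSG_ERRQUEUE); the errant packet's payload arrives in the normal buffer, the original destination in msg_name, and the sock_extended_err as a control message. A sketch using golang.org/x/sys/unix, assuming a little-endian Linux host:

package errqueuedemo

import (
	"encoding/binary"
	"fmt"

	"golang.org/x/sys/unix"
)

// drainOneError reads a single queued error from an IPv4 datagram socket.
func drainOneError(fd int) error {
	// Ask the stack to queue errors (e.g. ICMP port unreachable) for this socket.
	if err := unix.SetsockoptInt(fd, unix.SOL_IP, unix.IP_RECVERR, 1); err != nil {
		return err
	}
	payload := make([]byte, 2048) // payload of the datagram that caused the error
	oob := make([]byte, 512)      // space for control messages
	n, oobn, _, dst, err := unix.Recvmsg(fd, payload, oob, unix.MSG_ERRQUEUE|unix.MSG_DONTWAIT)
	if err != nil {
		return err // EAGAIN means the error queue is empty.
	}
	cmsgs, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return err
	}
	for _, cm := range cmsgs {
		if cm.Header.Level == unix.SOL_IP && cm.Header.Type == unix.IP_RECVERR {
			// The control data starts with a sock_extended_err whose first
			// 32-bit field is the errno of the queued error.
			errno := binary.LittleEndian.Uint32(cm.Data[:4])
			fmt.Printf("queued errno=%d payload=%d bytes original dst=%v\n", errno, n, dst)
		}
	}
	return nil
}
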
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..4adfa6637 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..987012acc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 27f96a3ac..89b765f1b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,10 +1,24 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "sock_err_list", + out = "sock_err_list.go", + package = "tcpip", + prefix = "sockError", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SockError", + "Linker": "*SockError", + }, +) + go_library( name = "tcpip", srcs = [ + "sock_err_list.go", "socketops.go", "tcpip.go", "time_unsafe.go", diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b60a5fd76..eb63d735f 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -104,7 +104,7 @@ type SocketOptions struct { keepAliveEnabled uint32 // multicastLoopEnabled determines whether multicast packets sent over a - // non-loopback interface will be looped back. Analogous to inet->mc_loop. + // non-loopback interface will be looped back. multicastLoopEnabled uint32 // receiveTOSEnabled is used to specify if the TOS ancillary message is @@ -145,6 +145,10 @@ type SocketOptions struct { // the incoming packet should be returned as an ancillary message. receiveOriginalDstAddress uint32 + // errQueue is the per-socket error queue. It is protected by errQueueMu. + errQueueMu sync.Mutex `state:"nosave"` + errQueue sockErrorList + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -362,3 +366,60 @@ func (so *SocketOptions) SetLinger(linger LingerOption) { so.linger = linger so.mu.Unlock() } + +// SockErrOrigin represents the constants for error origin. +type SockErrOrigin uint8 + +const ( + // SockExtErrorOriginNone represents an unknown error origin. + SockExtErrorOriginNone SockErrOrigin = iota + + // SockExtErrorOriginLocal indicates a local error. + SockExtErrorOriginLocal + + // SockExtErrorOriginICMP indicates an IPv4 ICMP error. + SockExtErrorOriginICMP + + // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error. + SockExtErrorOriginICMP6 +) + +// SockError represents a queue entry in the per-socket error queue. 
+// +// +stateify savable +type SockError struct { + sockErrorEntry + + // Err is the error caused by the errant packet. + Err *Error + // ErrOrigin indicates the error origin. + ErrOrigin SockErrOrigin + // ErrType is the type in the ICMP header. + ErrType uint8 + // ErrCode is the code in the ICMP header. + ErrCode uint8 + // ErrInfo is additional info about the error. + ErrInfo uint32 + + // Payload is the errant packet's payload. + Payload []byte + // Dst is the original destination address of the errant packet. + Dst FullAddress + // Offender is the original sender address of the errant packet. + Offender FullAddress + // NetProto is the network protocol being used to transmit the packet. + NetProto NetworkProtocolNumber +} + +// DequeueErr dequeues a socket extended error from the error queue and returns +// it. Returns nil if queue is empty. +func (so *SocketOptions) DequeueErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + + err := so.errQueue.Front() + if err != nil { + so.errQueue.Remove(err) + } + return err +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45fa62720..a488cc108 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -500,6 +500,9 @@ type ControlMessages struct { // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress FullAddress + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr *SockError } // PacketOwner is used to get UID and GID of the packet. -- cgit v1.2.3 From 30860902f6953348577e6a1d742521c6fbc4c75d Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 17 Dec 2020 10:53:50 -0800 Subject: Set process group and session on host TTY Closes #5128 PiperOrigin-RevId: 348052446 --- pkg/sentry/fsimpl/host/host.go | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 435a21d77..36a3f6810 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -31,6 +31,7 @@ import ( fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag fileDescription: fileDescription{inode: i}, termios: linux.DefaultReplicaTermios, } + if task := kernel.TaskFromContext(ctx); task != nil { + fd.fgProcessGroup = task.ThreadGroup().ProcessGroup() + fd.session = fd.fgProcessGroup.Session() + } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { -- cgit v1.2.3 From 028271b5308708463d2aa593122840e70c93f02c Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Thu, 17 Dec 2020 11:07:56 -0800 Subject: [netstack] Implement IP(V6)_RECVERR socket option. 
PiperOrigin-RevId: 348055514 --- pkg/sentry/socket/hostinet/socket.go | 8 +-- pkg/sentry/socket/netstack/netstack.go | 58 ++++++++++++++++++++-- pkg/tcpip/header/icmpv4.go | 14 ++++++ pkg/tcpip/socketops.go | 69 ++++++++++++++++++++++++++ pkg/tcpip/transport/packet/endpoint.go | 7 +++ pkg/tcpip/transport/tcp/endpoint.go | 54 +++++++++++++++++--- pkg/tcpip/transport/udp/endpoint.go | 79 +++++++++++++++++++++++++++-- runsc/boot/filter/config.go | 38 +++++++++++--- test/syscalls/linux/udp_socket.cc | 90 ++++++++++++++++++++++++++++++++++ 9 files changed, 389 insertions(+), 28 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 2b34ef190..5b868216d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,12 +331,12 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: @@ -377,14 +377,14 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index a8ab6b385..460c95b9f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1405,6 +1405,13 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil case linux.IPV6_RECVORIGDSTADDR: if outLen < sizeOfInt32 { @@ -1579,6 +1586,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) return &v, nil + case linux.IP_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -2129,6 +2144,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveTClass(v != 0) return nil + case linux.IPV6_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if 
err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2317,6 +2342,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetReceiveTOS(v != 0) return nil + case linux.IP_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil + case linux.IP_PKTINFO: if len(optVal) == 0 { return nil @@ -2386,7 +2422,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVTTL, @@ -2462,7 +2497,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_MULTICAST_IF, linux.IPV6_MULTICAST_LOOP, linux.IPV6_RECVDSTOPTS, - linux.IPV6_RECVERR, linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, @@ -2496,7 +2530,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { linux.IP_PKTINFO, linux.IP_PKTOPTIONS, linux.IP_MTU_DISCOVER, - linux.IP_RECVERR, linux.IP_RECVTTL, linux.IP_RECVTOS, linux.IP_MTU, @@ -2798,6 +2831,23 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb(). +func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { + so := s.Endpoint.SocketOptions() + err := so.DequeueErr() + if err == nil { + return nil + } + + // Update socket error to reflect ICMP errors in queue. + if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + so.SetLastError(nextErr.Err) + } else if err.ErrOrigin.IsICMPErr() { + so.SetLastError(nil) + } + return err +} + // addrFamilyFromNetProto returns the address family identifier for the given // network protocol. func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { @@ -2814,7 +2864,7 @@ func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { // recvErr handles MSG_ERRQUEUE for recvmsg(2). // This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - sockErr := s.Endpoint.SocketOptions().DequeueErr() + sockErr := s.dequeueErr() if sockErr == nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 2f13dea6a..1be90d7d5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -213,3 +214,16 @@ func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { return xsum } + +// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when +// a packet with a `net` header causes an ICMP error.
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { + switch net { + case IPv4ProtocolNumber: + return tcpip.SockExtErrorOriginICMP + case IPv6ProtocolNumber: + return tcpip.SockExtErrorOriginICMP6 + default: + panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index eb63d735f..095d1734a 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -42,6 +42,9 @@ type SocketOptionsHandler interface { // LastError is invoked when SO_ERROR is read for an endpoint. LastError() *Error + + // UpdateLastError updates the endpoint specific last error field. + UpdateLastError(err *Error) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -70,6 +73,9 @@ func (*DefaultSocketOptionsHandler) LastError() *Error { return nil } +// UpdateLastError implements SocketOptionsHandler.UpdateLastError. +func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -145,6 +151,10 @@ type SocketOptions struct { // the incoming packet should be returned as an ancillary message. receiveOriginalDstAddress uint32 + // recvErrEnabled determines whether extended reliable error message passing + // is enabled. + recvErrEnabled uint32 + // errQueue is the per-socket error queue. It is protected by errQueueMu. errQueueMu sync.Mutex `state:"nosave"` errQueue sockErrorList @@ -171,6 +181,11 @@ func storeAtomicBool(addr *uint32, v bool) { atomic.StoreUint32(addr, val) } +// SetLastError sets the last error for a socket. +func (so *SocketOptions) SetLastError(err *Error) { + so.handler.UpdateLastError(err) +} + // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { return atomic.LoadUint32(&so.broadcastEnabled) != 0 @@ -338,6 +353,19 @@ func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { storeAtomicBool(&so.receiveOriginalDstAddress, v) } +// GetRecvError gets value for IP*_RECVERR option. +func (so *SocketOptions) GetRecvError() bool { + return atomic.LoadUint32(&so.recvErrEnabled) != 0 +} + +// SetRecvError sets value for IP*_RECVERR option. +func (so *SocketOptions) SetRecvError(v bool) { + storeAtomicBool(&so.recvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + // GetLastError gets value for SO_ERROR option. func (so *SocketOptions) GetLastError() *Error { return so.handler.LastError() @@ -384,6 +412,11 @@ const ( SockExtErrorOriginICMP6 ) +// IsICMPErr indicates if the error originated from an ICMP error. +func (origin SockErrOrigin) IsICMPErr() bool { + return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 +} + // SockError represents a queue entry in the per-socket error queue. // // +stateify savable @@ -411,6 +444,13 @@ type SockError struct { NetProto NetworkProtocolNumber } +// pruneErrQueue resets the queue. +func (so *SocketOptions) pruneErrQueue() { + so.errQueueMu.Lock() + so.errQueue.Reset() + so.errQueueMu.Unlock() +} + // DequeueErr dequeues a socket extended error from the error queue and returns // it. Returns nil if queue is empty. func (so *SocketOptions) DequeueErr() *SockError { @@ -423,3 +463,32 @@ func (so *SocketOptions) DequeueErr() *SockError { } return err } + +// PeekErr returns the error in the front of the error queue. Returns nil if +// the error queue is empty. 
+func (so *SocketOptions) PeekErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + return so.errQueue.Front() +} + +// QueueErr inserts the error at the back of the error queue. +// +// Preconditions: so.GetRecvError() == true. +func (so *SocketOptions) QueueErr(err *SockError) { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + so.errQueue.PushBack(err) +} + +// QueueLocalErr queues a local error onto the local queue. +func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { + so.QueueErr(&SockError{ + Err: err, + ErrOrigin: SockExtErrorOriginLocal, + ErrInfo: info, + Payload: payload, + Dst: dst, + NetProto: net, + }) +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 9faab4b9e..e5e247342 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -366,6 +366,13 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bb0795f78..2128206d7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1303,6 +1303,15 @@ func (e *endpoint) LastError() *tcpip.Error { return e.lastErrorLocked() } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -2708,6 +2717,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { @@ -2722,16 +2766,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8e16c8435..d919fa011 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -226,6 +226,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -511,6 +518,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -1338,15 +1359,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - - e.waiterQueue.Notify(waiter.EventErr) + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) return } } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 4e3bb9ac7..eacd73531 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -351,6 +351,11 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.EqualTo(syscall.SOL_IP), seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVERR), + }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), @@ -361,6 +366,11 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.EqualTo(syscall.SOL_IPV6), seccomp.EqualTo(syscall.IPV6_RECVTCLASS), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVERR), + }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), @@ -444,13 +454,6 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_SENDMSG: {}, syscall.SYS_SENDTO: {}, syscall.SYS_SETSOCKOPT: []seccomp.Rule{ - { - seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_V6ONLY), - seccomp.MatchAny{}, - seccomp.EqualTo(4), - }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_SOCKET), @@ -521,6 +524,13 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.MatchAny{}, seccomp.EqualTo(4), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVERR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), @@ -542,6 +552,20 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.MatchAny{}, seccomp.EqualTo(4), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVERR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ { diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 90ef8bf21..21727a2e7 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -14,6 +14,8 @@ #include #include +#include +#include #include @@ -779,6 +781,94 @@ TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) { SyscallFailsWithErrno(ECONNREFUSED)); } +#ifdef __linux__ +TEST_P(UdpSocketTest, RecvErrorConnRefused) { + // We will simulate an ICMP error and verify that we do receive that error via + // recvmsg(MSG_ERRQUEUE). + ASSERT_NO_ERRNO(BindLoopback()); + // Close the socket to release the port so that we get an ICMP error. 
+ ASSERT_THAT(close(bind_.release()), SyscallSucceeds()); + + // Set IP_RECVERR socket option to enable error queueing. + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + int opt_level = SOL_IP; + int opt_type = IP_RECVERR; + if (GetParam() != AddressFamily::kIpv4) { + opt_level = SOL_IPV6; + opt_type = IPV6_RECVERR; + } + ASSERT_THAT(setsockopt(sock_.get(), opt_level, opt_type, &v, optlen), + SyscallSucceeds()); + + // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an + // UDP socket. There is no easy way to ensure that the UDP port is not bound + // by another conncurrently running test. *This is potentially flaky*. + const int kBufLen = 300; + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + char buf[kBufLen]; + RandomizeBuffer(buf, sizeof(buf)); + // Send from sock_ to an unbound port. This should cause ECONNREFUSED. + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + // Dequeue error using recvmsg(MSG_ERRQUEUE). + char got[kBufLen]; + struct iovec iov; + iov.iov_base = reinterpret_cast(got); + iov.iov_len = kBufLen; + + size_t control_buf_len = CMSG_SPACE(sizeof(sock_extended_err) + addrlen_); + char* control_buf = static_cast(calloc(1, control_buf_len)); + struct sockaddr_storage remote; + memset(&remote, 0, sizeof(remote)); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_flags = 0; + msg.msg_control = control_buf; + msg.msg_controllen = control_buf_len; + msg.msg_name = reinterpret_cast(&remote); + msg.msg_namelen = addrlen_; + ASSERT_THAT(recvmsg(sock_.get(), &msg, MSG_ERRQUEUE), + SyscallSucceedsWithValue(kBufLen)); + + // Check the contents of msg. + EXPECT_EQ(memcmp(got, buf, sizeof(buf)), 0); // iovec check + EXPECT_NE(msg.msg_flags & MSG_ERRQUEUE, 0); + EXPECT_EQ(memcmp(&remote, bind_addr_, addrlen_), 0); + + // Check the contents of the control message. + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(CMSG_NXTHDR(&msg, cmsg), nullptr); + EXPECT_EQ(cmsg->cmsg_level, opt_level); + EXPECT_EQ(cmsg->cmsg_type, opt_type); + + // Check the contents of socket error. + struct sock_extended_err* sock_err = + (struct sock_extended_err*)CMSG_DATA(cmsg); + EXPECT_EQ(sock_err->ee_errno, ECONNREFUSED); + if (GetParam() == AddressFamily::kIpv4) { + EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP); + EXPECT_EQ(sock_err->ee_type, ICMP_DEST_UNREACH); + EXPECT_EQ(sock_err->ee_code, ICMP_PORT_UNREACH); + } else { + EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP6); + EXPECT_EQ(sock_err->ee_type, ICMP6_DST_UNREACH); + EXPECT_EQ(sock_err->ee_code, ICMP6_DST_UNREACH_NOPORT); + } + + // Now verify that the socket error was cleared by recvmsg(MSG_ERRQUEUE). + int err; + optlen = sizeof(err); + ASSERT_THAT(getsockopt(sock_.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, 0); + ASSERT_EQ(optlen, sizeof(err)); +} +#endif // __linux__ + TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); -- cgit v1.2.3 From 1ea241e4cc9529d45817e448c66f85213778f948 Mon Sep 17 00:00:00 2001 From: Nicolas Lacasse Date: Thu, 17 Dec 2020 11:11:23 -0800 Subject: Fix seek on /proc/pid/cmdline when task is zombie. 
PiperOrigin-RevId: 348056159 --- pkg/sentry/fsimpl/proc/task_files.go | 17 +++++++++++------ test/syscalls/linux/proc.cc | 27 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index a3780b222..75be6129f 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager { // MemoryManager's users count is incremented, and must be decremented by the // caller when it is no longer in use. func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { - if task.ExitState() == kernel.TaskExitDead { - return nil, syserror.ESRCH - } var m *mm.MemoryManager task.WithMuLocked(func(t *kernel.Task) { m = t.MemoryManager() @@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 } m, err := getMMIncRef(fd.inode.task) if err != nil { - return 0, nil + return 0, err } defer m.DecUsers(ctx) // Buffer the read data because of MM locks diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 575be014c..e508ce27f 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1802,6 +1802,33 @@ TEST(ProcPidCmdline, SubprocessForkSameCmdline) { } } +TEST(ProcPidCmdline, SubprocessSeekCmdline) { + FileDescriptor fd; + ASSERT_NO_ERRNO(WithSubprocess( + [&](int pid) -> PosixError { + // Running. Open /proc/pid/cmdline. + ASSIGN_OR_RETURN_ERRNO( + fd, Open(absl::StrCat("/proc/", pid, "/cmdline"), O_RDONLY)); + return NoError(); + }, + [&](int pid) -> PosixError { + // Zombie, but seek should still succeed. + int ret = lseek(fd.get(), 0x801, 0); + if (ret < 0) { + return PosixError(errno); + } + return NoError(); + }, + [&](int pid) -> PosixError { + // Exited. + int ret = lseek(fd.get(), 0x801, 0); + if (ret < 0) { + return PosixError(errno); + } + return NoError(); + })); +} + // Test whether /proc/PID/ symlinks can be read for a running process. TEST(ProcPidSymlink, SubprocessRunning) { char buf[1]; -- cgit v1.2.3 From 433fd0e64650e31ab28e9d918d6dfcd9a67b4246 Mon Sep 17 00:00:00 2001 From: Chong Cai Date: Thu, 17 Dec 2020 14:20:56 -0800 Subject: Set verityMu to be state nosave PiperOrigin-RevId: 348092999 --- pkg/sentry/fsimpl/verity/verity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 66029c64d..a5171b5ad 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -180,7 +180,7 @@ type filesystem struct { // its children. So they shouldn't be enabled the same time. 
This lock // is for the whole file system to ensure that no more than one file is // enabled the same time. - verityMu sync.RWMutex + verityMu sync.RWMutex `state:"nosave"` } // InternalFilesystemOptions may be passed as -- cgit v1.2.3 From 7c8ba72b026db3b79f12e679ab69078a25c143e8 Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Tue, 22 Dec 2020 14:41:11 -0800 Subject: Move SO_BINDTODEVICE to socketops. PiperOrigin-RevId: 348696094 --- pkg/sentry/socket/netstack/netstack.go | 11 +++------ pkg/tcpip/socketops.go | 26 +++++++++++++++++++++ pkg/tcpip/stack/transport_demuxer_test.go | 5 ++-- pkg/tcpip/tcpip.go | 8 ------- pkg/tcpip/transport/tcp/endpoint.go | 39 +++++++++++-------------------- pkg/tcpip/transport/tcp/tcp_test.go | 17 ++++++-------- pkg/tcpip/transport/udp/endpoint.go | 30 ++++++++---------------- pkg/tcpip/transport/udp/forwarder.go | 2 +- pkg/tcpip/transport/udp/udp_test.go | 12 ++++------ 9 files changed, 68 insertions(+), 82 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 460c95b9f..3f587638f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1042,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &v, nil case linux.SO_BINDTODEVICE: - var v tcpip.BindToDeviceOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetBindToDevice() if v == 0 { var b primitive.ByteSlice return &b, nil @@ -1804,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - v := tcpip.BindToDeviceOption(0) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0)) } s := t.NetworkContext() if s == nil { @@ -1813,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - v := tcpip.BindToDeviceOption(nicID) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID)) } } return syserr.ErrUnknownDevice diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 095d1734a..f3ad40fdf 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -45,6 +45,9 @@ type SocketOptionsHandler interface { // UpdateLastError updates the endpoint specific last error field. UpdateLastError(err *Error) + + // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. + HasNIC(v int32) bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -76,6 +79,11 @@ func (*DefaultSocketOptionsHandler) LastError() *Error { // UpdateLastError implements SocketOptionsHandler.UpdateLastError. func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} +// HasNIC implements SocketOptionsHandler.HasNIC. +func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { + return false +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -159,6 +167,9 @@ type SocketOptions struct { errQueueMu sync.Mutex `state:"nosave"` errQueue sockErrorList + // bindToDevice determines the device to which the socket is bound. + bindToDevice int32 + // mu protects the access to the below fields. 
mu sync.Mutex `state:"nosave"` @@ -492,3 +503,18 @@ func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, in NetProto: net, }) } + +// GetBindToDevice gets value for SO_BINDTODEVICE option. +func (so *SocketOptions) GetBindToDevice() int32 { + return atomic.LoadInt32(&so.bindToDevice) +} + +// SetBindToDevice sets value for SO_BINDTODEVICE option. +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { + if !so.handler.HasNIC(bindToDevice) { + return ErrUnknownDevice + } + + atomic.StoreInt32(&so.bindToDevice, bindToDevice) + return nil +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index a692af20b..737d8d912 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -308,9 +308,8 @@ func TestBindToDeviceDistribution(t *testing.T) { defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) + if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d195304be..ef0f51f1a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -955,14 +955,6 @@ type SettableSocketOption interface { isSettableSocketOption() } -// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets -// should bind only on a specific NIC. -type BindToDeviceOption NICID - -func (*BindToDeviceOption) isGettableSocketOption() {} - -func (*BindToDeviceOption) isSettableSocketOption() {} - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c88e74bec..6e3c8860e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -502,9 +502,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // bindToDevice is set to the NIC on which to bind or disabled if 0. - bindToDevice tcpip.NICID - // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -1821,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.LockUser() - e.bindToDevice = id - e.UnlockUser() - case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(*v) @@ -2013,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case *tcpip.BindToDeviceOption: - e.LockUser() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.UnlockUser() - case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.LockUser() @@ -2220,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } } + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2262,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil } @@ -2281,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // the selected port. e.ID = id e.isPortReserved = true - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil @@ -2634,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2645,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. 
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2654,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. e.boundNICID = nic diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 351a5e4f5..cf60d5b53 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1380,9 +1380,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.Cleanup() c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - if err := c.EP.SetSockOpt(&bindToDevice); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) @@ -4507,7 +4506,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -4517,15 +4516,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 24d0c2cb9..9b9e4deb0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,7 +109,6 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID portFlags ports.Flags - bindToDevice tcpip.NICID lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -659,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. 
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -775,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil } @@ -859,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1113,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6f89b6271..8429f34b4 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -554,7 +554,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ 
{"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -564,15 +564,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) -- cgit v1.2.3 From d07915987631f4c3c6345275019a5b5b0cf28dbb Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Wed, 23 Dec 2020 11:08:42 -0800 Subject: vfs1: don't allow to open socket files open() has to return ENXIO in this case. O_PATH isn't supported by vfs1. PiperOrigin-RevId: 348820478 --- pkg/sentry/fs/gofer/inode.go | 3 +++ pkg/sentry/fs/host/inode.go | 4 ++++ pkg/sentry/fs/ramfs/socket.go | 3 ++- pkg/sentry/fs/tmpfs/inode_file.go | 4 ++++ pkg/sentry/syscalls/linux/sys_file.go | 12 ++++++------ test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_unix_unbound_filesystem.cc | 16 ++++++++++++++++ 7 files changed, 36 insertions(+), 7 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 9d6fdd08f..e840b6f5e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { switch d.Inode.StableAttr.Type { case fs.Socket: + if i.session().overrides != nil { + return nil, syserror.ENXIO + } return i.getFileSocket(ctx, d, flags) case fs.Pipe: return i.getFilePipe(ctx, d, flags) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index fbfba1b58..2c14aa6d9 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport. // GetFile implements fs.InodeOperations.GetFile. func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + return newFile(ctx, d, flags, i), nil } diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index 29ff004f2..d0c565879 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { // GetFile implements fs.FileOperations.GetFile. 
func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil + return nil, syserror.ENXIO } // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index e04cd608d..ad4aea282 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + if flags.Write { fsmetric.TmpfsOpensW.Increment() } else if flags.Read { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 8db587401..c33571f43 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } + file, err := d.Inode.GetFile(t, d, fileFlags) + if err != nil { + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + } + defer file.DecRef(t) + // Truncate is called when O_TRUNC is specified for any kind of // existing Dirent. Behavior is delegated to the entry's Truncate // implementation. @@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } - file, err := d.Inode.GetFile(t, d, fileFlags) - if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) - } - defer file.DecRef(t) - // Success. newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index a9d91c589..89d532c70 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3324,6 +3324,7 @@ cc_binary( ":socket_test_util", ":unix_domain_socket_test_util", gtest, + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", ], diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index cab912152..a035fb095 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" namespace gvisor { @@ -70,6 +72,20 @@ TEST_P(UnboundFilesystemUnixSocketPairTest, GetSockNameLength) { strlen(want_addr.sun_path) + 1 + sizeof(want_addr.sun_family)); } +TEST_P(UnboundFilesystemUnixSocketPairTest, OpenSocketWithTruncate) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + const struct sockaddr_un *addr = + reinterpret_cast(sockets->first_addr()); + EXPECT_THAT(chmod(addr->sun_path, 0777), SyscallSucceeds()); + EXPECT_THAT(open(addr->sun_path, O_RDONLY | O_TRUNC), + SyscallFailsWithErrno(ENXIO)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnboundFilesystemUnixSocketPairTest, ::testing::ValuesIn(ApplyVec( -- cgit v1.2.3
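The ENXIO behaviour enforced by this last change is easy to observe from application code. Below is a minimal, standalone Go sketch, not part of the series; the /tmp path is an arbitrary assumption. It binds a UNIX domain socket, then shows that open(2) on the socket file fails with ENXIO while connecting through the socket API still works.

package main

import (
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"syscall"
)

func main() {
	// Arbitrary scratch path for the demo socket.
	path := "/tmp/enxio-demo.sock"
	_ = os.Remove(path) // best-effort cleanup from a previous run

	// Bind a UNIX domain socket at path.
	l, err := net.Listen("unix", path)
	if err != nil {
		log.Fatal(err)
	}
	defer l.Close()
	defer os.Remove(path)

	// open(2) on a socket file is expected to fail with ENXIO.
	_, err = os.OpenFile(path, os.O_RDONLY, 0)
	fmt.Println("open returns ENXIO:", errors.Is(err, syscall.ENXIO))

	// The socket API path is unaffected: connect(2) still succeeds.
	c, err := net.Dial("unix", path)
	if err != nil {
		log.Fatal(err)
	}
	c.Close()
	fmt.Println("connect succeeded")
}

Before this change, VFS1 would hand back a stub file for such an open, which is what the fs.NewFile(...) line removed from ramfs/socket.go used to do.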