summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJamie Liu <jamieliu@google.com>2020-05-12 12:24:23 -0700
committergVisor bot <gvisor-bot@google.com>2020-05-12 12:26:05 -0700
commit8dd1d5b75a95100e747b1a88e9e557d5d2c30b64 (patch)
tree9846e2c721ed9bd4d02390b23d7b26d9212b7965
parent06ded1c4372d4871f0581c7090957935d93cd50e (diff)
Don't call kernel.Task.Block() from netstack.SocketOperations.Write().
kernel.Task.Block() requires that the caller is running on the task goroutine. netstack.SocketOperations.Write() uses kernel.TaskFromContext() to call kernel.Task.Block() even if it's not running on the task goroutine. Stop doing that. PiperOrigin-RevId: 311178335
-rw-r--r--pkg/amutex/BUILD1
-rw-r--r--pkg/amutex/amutex.go17
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go13
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go17
5 files changed, 27 insertions, 22 deletions
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 9612f072e..ffc918846 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "amutex",
srcs = ["amutex.go"],
visibility = ["//:sandbox"],
+ deps = ["//pkg/syserror"],
)
go_test(
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
index 1c4fd1784..a078a31db 100644
--- a/pkg/amutex/amutex.go
+++ b/pkg/amutex/amutex.go
@@ -18,6 +18,8 @@ package amutex
import (
"sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/syserror"
)
// Sleeper must be implemented by users of the abortable mutex to allow for
@@ -53,6 +55,21 @@ func (NoopSleeper) SleepFinish(success bool) {}
// Interrupted implements Sleeper.Interrupted.
func (NoopSleeper) Interrupted() bool { return false }
+// Block blocks until either receiving from ch succeeds (in which case it
+// returns nil) or sleeper is interrupted (in which case it returns
+// syserror.ErrInterrupted).
+func Block(sleeper Sleeper, ch <-chan struct{}) error {
+ cancel := sleeper.SleepStart()
+ select {
+ case <-ch:
+ sleeper.SleepFinish(true)
+ return nil
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
+
// AbortableMutex is an abortable mutex. It allows Lock() to be aborted while it
// waits to acquire the mutex.
type AbortableMutex struct {
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 6129fb83d..333e0042e 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -18,6 +18,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/amutex",
"//pkg/binary",
"//pkg/context",
"//pkg/log",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 81053d8ef..9d032f052 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -34,6 +34,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -553,11 +554,9 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := kernel.TaskFromContext(ctx)
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err).ToError()
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
}
-
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
}
@@ -626,11 +625,9 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := kernel.TaskFromContext(ctx)
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err).ToError()
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
}
-
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
Atomic: true, // See above.
})
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 191970d41..fcd8013c0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -16,6 +16,7 @@ package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -89,11 +90,6 @@ func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
s.socketOpsCommon.EventUnregister(e)
}
-// PRead implements vfs.FileDescriptionImpl.
-func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- return 0, syserror.ESPIPE
-}
-
// Read implements vfs.FileDescriptionImpl.
func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
// All flags other than RWF_NOWAIT should be ignored.
@@ -115,11 +111,6 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.
return int64(n), nil
}
-// PWrite implements vfs.FileDescriptionImpl.
-func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.ESPIPE
-}
-
// Write implements vfs.FileDescriptionImpl.
func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
// All flags other than RWF_NOWAIT should be ignored.
@@ -135,11 +126,9 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
}
if resCh != nil {
- t := kernel.TaskFromContext(ctx)
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err).ToError()
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
}
-
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
}