summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2018-12-14 16:12:51 -0800
committerShentubot <shentubot@google.com>2018-12-14 16:15:06 -0800
commite1dcf92ec5cf7d9bf58fb322f46f6ae2d98699d2 (patch)
tree61ed22a594dd96bb994d748e83358c5a51212ee5 /pkg/sentry/socket
parented930354ef46df9b6feece36e59ee644a7cdfa7f (diff)
Implement SO_SNDTIMEO
PiperOrigin-RevId: 225620490 Change-Id: Ia726107b3f58093a5f881634f90b071b32d2c269
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go35
-rw-r--r--pkg/sentry/socket/hostinet/socket.go9
-rw-r--r--pkg/sentry/socket/netlink/socket.go4
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go41
-rw-r--r--pkg/sentry/socket/socket.go49
-rw-r--r--pkg/sentry/socket/unix/unix.go16
6 files changed, 121 insertions, 33 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 19af7bc45..ab5d82183 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -30,6 +30,7 @@ import (
"strings"
"sync"
"syscall"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
@@ -137,12 +138,12 @@ type commonEndpoint interface {
//
// +stateify savable
type SocketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
*waiter.Queue
family int
@@ -643,7 +644,16 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
}
return syscall.Linger{}, nil
+ case linux.SO_SNDTIMEO:
+ // TODO: Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+
case linux.SO_RCVTIMEO:
+ // TODO: Linux allows shorter lengths for partial results.
if outLen < linux.SizeOfTimeval {
return nil, syserr.ErrInvalidArgument
}
@@ -833,6 +843,19 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ case linux.SO_SNDTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+
case linux.SO_RCVTIMEO:
if len(optVal) < linux.SizeOfTimeval {
return syserr.ErrInvalidArgument
@@ -840,6 +863,9 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
var v linux.Timeval
binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
s.SetRecvTimeout(v.ToNsecCapped())
return nil
@@ -1365,7 +1391,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Reject Unix control messages.
if !controlMessages.Unix.Empty() {
return 0, syserr.ErrInvalidArgument
@@ -1431,7 +1457,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(total), nil
}
- if err := t.Block(ch); err != nil {
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
// handleIOError will consume errors from t.Block if needed.
return int(total), syserr.FromError(err)
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index e4e950fbb..34281cac0 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -46,12 +46,12 @@ const (
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
fd int // must be O_NONBLOCK
queue waiter.Queue
@@ -418,7 +418,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -468,7 +468,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
}
if ch != nil {
- if err = t.Block(ch); err != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
break
}
} else {
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index c4798839e..0a7d4772c 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -65,12 +65,12 @@ var netlinkSocketDevice = device.NewAnonDevice()
//
// +stateify savable
type Socket struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
// ports provides netlink port allocation.
ports *port.Manager
@@ -593,7 +593,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte,
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
return s.sendMsg(t, src, to, flags, controlMessages)
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 90844f10f..257bc2d71 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -17,6 +17,7 @@ package rpcinet
import (
"sync/atomic"
"syscall"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
@@ -44,12 +45,12 @@ import (
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
@@ -379,7 +380,8 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements socket.Socket.GetSockOpt.
func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO is special because blocking is performed within the sentry.
+ // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
+ // within the sentry.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if outLen < linux.SizeOfTimeval {
return nil, syserr.ErrInvalidArgument
@@ -387,6 +389,13 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
return linux.NsecToTimeval(s.RecvTimeout()), nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
@@ -403,8 +412,9 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Because blocking actually happens within the sentry we need to inspect
- // this socket option to determine if it's a SO_RCVTIMEO, and if so, we will
- // save it and use it as the deadline for recv(2) related syscalls.
+ // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
+ // and if so, we will save it and use it as the deadline for recv(2)
+ // or send(2) related syscalls.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if len(opt) < linux.SizeOfTimeval {
return syserr.ErrInvalidArgument
@@ -412,9 +422,25 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
var v linux.Timeval
binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
s.SetRecvTimeout(v.ToNsecCapped())
return nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if len(opt) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
@@ -720,7 +746,7 @@ func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -787,7 +813,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(totalWritten), nil
}
- if err := t.Block(ch); err != nil {
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(totalWritten), syserr.ErrTryAgain
+ }
return int(totalWritten), syserr.FromError(err)
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index f73127ea6..9d4aaeb9d 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -94,15 +94,23 @@ type Socket interface {
// ownership of the ControlMessage on error.
//
// If n > 0, err will either be nil or an error from t.Block.
- SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error)
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error)
// SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
- // no timeout.
+ // no timeout, and negative means DONTWAIT.
SetRecvTimeout(nanoseconds int64)
// RecvTimeout gets the current timeout (in ns) for recv operations. Zero
- // means no timeout.
+ // means no timeout, and negative means DONTWAIT.
RecvTimeout() int64
+
+ // SetSendTimeout sets the timeout (in ns) for send operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetSendTimeout(nanoseconds int64)
+
+ // SendTimeout gets the current timeout (in ns) for send operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ SendTimeout() int64
}
// Provider is the interface implemented by providers of sockets for specific
@@ -192,30 +200,45 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(inode, fmt.Sprintf("socket:[%d]", ino))
}
-// ReceiveTimeout stores a timeout for receive calls.
+// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
// interface.
//
-// Care must be taken when copying ReceiveTimeout as it contains atomic
+// Care must be taken when copying SendReceiveTimeout as it contains atomic
// variables.
//
// +stateify savable
-type ReceiveTimeout struct {
- // ns is length of the timeout in nanoseconds.
+type SendReceiveTimeout struct {
+ // send is length of the send timeout in nanoseconds.
+ //
+ // send must be accessed atomically.
+ send int64
+
+ // recv is length of the receive timeout in nanoseconds.
//
- // ns must be accessed atomically.
- ns int64
+ // recv must be accessed atomically.
+ recv int64
}
// SetRecvTimeout implements Socket.SetRecvTimeout.
-func (rt *ReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
- atomic.StoreInt64(&rt.ns, nanoseconds)
+func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.recv, nanoseconds)
}
// RecvTimeout implements Socket.RecvTimeout.
-func (rt *ReceiveTimeout) RecvTimeout() int64 {
- return atomic.LoadInt64(&rt.ns)
+func (to *SendReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&to.recv)
+}
+
+// SetSendTimeout implements Socket.SetSendTimeout.
+func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.send, nanoseconds)
+}
+
+// SendTimeout implements Socket.SendTimeout.
+func (to *SendReceiveTimeout) SendTimeout() int64 {
+ return atomic.LoadInt64(&to.send)
}
// GetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4c9dcbd61..da225eabb 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -45,15 +45,16 @@ import (
//
// +stateify savable
type SocketOperations struct {
- refs.AtomicRefCount
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
- ep transport.Endpoint
- isPacket bool
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ isPacket bool
}
// New creates a new unix socket.
@@ -367,7 +368,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Endpoint: s.ep,
Control: controlMessages.Unix,
@@ -404,7 +405,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
break
}
- if err := t.Block(ch); err != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
break
}
}