summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/rpcinet/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go41
1 files changed, 35 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 90844f10f..257bc2d71 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -17,6 +17,7 @@ package rpcinet
import (
"sync/atomic"
"syscall"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
@@ -44,12 +45,12 @@ import (
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
@@ -379,7 +380,8 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements socket.Socket.GetSockOpt.
func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO is special because blocking is performed within the sentry.
+ // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
+ // within the sentry.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if outLen < linux.SizeOfTimeval {
return nil, syserr.ErrInvalidArgument
@@ -387,6 +389,13 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
return linux.NsecToTimeval(s.RecvTimeout()), nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
@@ -403,8 +412,9 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Because blocking actually happens within the sentry we need to inspect
- // this socket option to determine if it's a SO_RCVTIMEO, and if so, we will
- // save it and use it as the deadline for recv(2) related syscalls.
+ // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
+ // and if so, we will save it and use it as the deadline for recv(2)
+ // or send(2) related syscalls.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if len(opt) < linux.SizeOfTimeval {
return syserr.ErrInvalidArgument
@@ -412,9 +422,25 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
var v linux.Timeval
binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
s.SetRecvTimeout(v.ToNsecCapped())
return nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if len(opt) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
@@ -720,7 +746,7 @@ func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -787,7 +813,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(totalWritten), nil
}
- if err := t.Block(ch); err != nil {
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(totalWritten), syserr.ErrTryAgain
+ }
return int(totalWritten), syserr.FromError(err)
}
}