diff options
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 41 |
1 files changed, 35 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 90844f10f..257bc2d71 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -17,6 +17,7 @@ package rpcinet import ( "sync/atomic" "syscall" + "time" "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/binary" @@ -44,12 +45,12 @@ import ( // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` + socket.SendReceiveTimeout fd uint32 // must be O_NONBLOCK wq *waiter.Queue @@ -379,7 +380,8 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements socket.Socket.GetSockOpt. func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) { - // SO_RCVTIMEO is special because blocking is performed within the sentry. + // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed + // within the sentry. if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { if outLen < linux.SizeOfTimeval { return nil, syserr.ErrInvalidArgument @@ -387,6 +389,13 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe return linux.NsecToTimeval(s.RecvTimeout()), nil } + if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { + if outLen < linux.SizeOfTimeval { + return nil, syserr.ErrInvalidArgument + } + + return linux.NsecToTimeval(s.SendTimeout()), nil + } stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */) @@ -403,8 +412,9 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { // Because blocking actually happens within the sentry we need to inspect - // this socket option to determine if it's a SO_RCVTIMEO, and if so, we will - // save it and use it as the deadline for recv(2) related syscalls. + // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO, + // and if so, we will save it and use it as the deadline for recv(2) + // or send(2) related syscalls. if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { if len(opt) < linux.SizeOfTimeval { return syserr.ErrInvalidArgument @@ -412,9 +422,25 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ var v linux.Timeval binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } s.SetRecvTimeout(v.ToNsecCapped()) return nil } + if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { + if len(opt) < linux.SizeOfTimeval { + return syserr.ErrInvalidArgument + } + + var v linux.Timeval + binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } + s.SetSendTimeout(v.ToNsecCapped()) + return nil + } stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */) @@ -720,7 +746,7 @@ func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Whitelist flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument @@ -787,7 +813,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(totalWritten), nil } - if err := t.Block(ch); err != nil { + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(totalWritten), syserr.ErrTryAgain + } return int(totalWritten), syserr.FromError(err) } } |