diff options
author | Ian Gudger <igudger@google.com> | 2018-12-14 16:12:51 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-14 16:15:06 -0800 |
commit | e1dcf92ec5cf7d9bf58fb322f46f6ae2d98699d2 (patch) | |
tree | 61ed22a594dd96bb994d748e83358c5a51212ee5 | |
parent | ed930354ef46df9b6feece36e59ee644a7cdfa7f (diff) |
Implement SO_SNDTIMEO
PiperOrigin-RevId: 225620490
Change-Id: Ia726107b3f58093a5f881634f90b071b32d2c269
-rw-r--r-- | pkg/sentry/fs/host/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_test.go | 3 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 35 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 41 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 49 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 16 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 36 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 110 | ||||
-rw-r--r-- | test/syscalls/linux/socket_stream_blocking.cc | 22 | ||||
-rw-r--r-- | test/syscalls/linux/socket_unix_non_stream.cc | 22 |
12 files changed, 307 insertions, 41 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 89d7b2fe7..73d9cc71a 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -71,6 +71,7 @@ go_test( "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", + "//pkg/sentry/kernel/time", "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 17bf397ef..6ddf63a6a 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -21,6 +21,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest" + ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" @@ -142,7 +143,7 @@ func TestSocketSendMsgLen0(t *testing.T) { defer sfile.DecRef() s := sfile.FileOperations.(socket.Socket) - n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, socket.ControlMessages{}) + n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) if n != 0 { t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n) } diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 19af7bc45..ab5d82183 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -30,6 +30,7 @@ import ( "strings" "sync" "syscall" + "time" "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/binary" @@ -137,12 +138,12 @@ type commonEndpoint interface { // // +stateify savable type SocketOperations struct { - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` + socket.SendReceiveTimeout *waiter.Queue family int @@ -643,7 +644,16 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family } return syscall.Linger{}, nil + case linux.SO_SNDTIMEO: + // TODO: Linux allows shorter lengths for partial results. + if outLen < linux.SizeOfTimeval { + return nil, syserr.ErrInvalidArgument + } + + return linux.NsecToTimeval(s.SendTimeout()), nil + case linux.SO_RCVTIMEO: + // TODO: Linux allows shorter lengths for partial results. if outLen < linux.SizeOfTimeval { return nil, syserr.ErrInvalidArgument } @@ -833,6 +843,19 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + case linux.SO_SNDTIMEO: + if len(optVal) < linux.SizeOfTimeval { + return syserr.ErrInvalidArgument + } + + var v linux.Timeval + binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } + s.SetSendTimeout(v.ToNsecCapped()) + return nil + case linux.SO_RCVTIMEO: if len(optVal) < linux.SizeOfTimeval { return syserr.ErrInvalidArgument @@ -840,6 +863,9 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i var v linux.Timeval binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } s.SetRecvTimeout(v.ToNsecCapped()) return nil @@ -1365,7 +1391,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements the linux syscall sendmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Reject Unix control messages. if !controlMessages.Unix.Empty() { return 0, syserr.ErrInvalidArgument @@ -1431,7 +1457,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(total), nil } - if err := t.Block(ch); err != nil { + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } // handleIOError will consume errors from t.Block if needed. return int(total), syserr.FromError(err) } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index e4e950fbb..34281cac0 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -46,12 +46,12 @@ const ( // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` + socket.SendReceiveTimeout fd int // must be O_NONBLOCK queue waiter.Queue @@ -418,7 +418,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Whitelist flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument @@ -468,7 +468,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err)) } if ch != nil { - if err = t.Block(ch); err != nil { + if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + err = syserror.ErrWouldBlock + } break } } else { diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index c4798839e..0a7d4772c 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -65,12 +65,12 @@ var netlinkSocketDevice = device.NewAnonDevice() // // +stateify savable type Socket struct { - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` + socket.SendReceiveTimeout // ports provides netlink port allocation. ports *port.Manager @@ -593,7 +593,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, } // SendMsg implements socket.Socket.SendMsg. -func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { return s.sendMsg(t, src, to, flags, controlMessages) } diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 90844f10f..257bc2d71 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -17,6 +17,7 @@ package rpcinet import ( "sync/atomic" "syscall" + "time" "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/binary" @@ -44,12 +45,12 @@ import ( // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` + socket.SendReceiveTimeout fd uint32 // must be O_NONBLOCK wq *waiter.Queue @@ -379,7 +380,8 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements socket.Socket.GetSockOpt. func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) { - // SO_RCVTIMEO is special because blocking is performed within the sentry. + // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed + // within the sentry. if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { if outLen < linux.SizeOfTimeval { return nil, syserr.ErrInvalidArgument @@ -387,6 +389,13 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe return linux.NsecToTimeval(s.RecvTimeout()), nil } + if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { + if outLen < linux.SizeOfTimeval { + return nil, syserr.ErrInvalidArgument + } + + return linux.NsecToTimeval(s.SendTimeout()), nil + } stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */) @@ -403,8 +412,9 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { // Because blocking actually happens within the sentry we need to inspect - // this socket option to determine if it's a SO_RCVTIMEO, and if so, we will - // save it and use it as the deadline for recv(2) related syscalls. + // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO, + // and if so, we will save it and use it as the deadline for recv(2) + // or send(2) related syscalls. if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { if len(opt) < linux.SizeOfTimeval { return syserr.ErrInvalidArgument @@ -412,9 +422,25 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ var v linux.Timeval binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } s.SetRecvTimeout(v.ToNsecCapped()) return nil } + if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { + if len(opt) < linux.SizeOfTimeval { + return syserr.ErrInvalidArgument + } + + var v linux.Timeval + binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) + if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { + return syserr.ErrDomain + } + s.SetSendTimeout(v.ToNsecCapped()) + return nil + } stack := t.NetworkContext().(*Stack) id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */) @@ -720,7 +746,7 @@ func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Whitelist flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument @@ -787,7 +813,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(totalWritten), nil } - if err := t.Block(ch); err != nil { + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(totalWritten), syserr.ErrTryAgain + } return int(totalWritten), syserr.FromError(err) } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f73127ea6..9d4aaeb9d 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -94,15 +94,23 @@ type Socket interface { // ownership of the ControlMessage on error. // // If n > 0, err will either be nil or an error from t.Block. - SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error) + SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error) // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means - // no timeout. + // no timeout, and negative means DONTWAIT. SetRecvTimeout(nanoseconds int64) // RecvTimeout gets the current timeout (in ns) for recv operations. Zero - // means no timeout. + // means no timeout, and negative means DONTWAIT. RecvTimeout() int64 + + // SetSendTimeout sets the timeout (in ns) for send operations. Zero means + // no timeout, and negative means DONTWAIT. + SetSendTimeout(nanoseconds int64) + + // SendTimeout gets the current timeout (in ns) for send operations. Zero + // means no timeout, and negative means DONTWAIT. + SendTimeout() int64 } // Provider is the interface implemented by providers of sockets for specific @@ -192,30 +200,45 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent { return fs.NewDirent(inode, fmt.Sprintf("socket:[%d]", ino)) } -// ReceiveTimeout stores a timeout for receive calls. +// SendReceiveTimeout stores timeouts for send and receive calls. // // It is meant to be embedded into Socket implementations to help satisfy the // interface. // -// Care must be taken when copying ReceiveTimeout as it contains atomic +// Care must be taken when copying SendReceiveTimeout as it contains atomic // variables. // // +stateify savable -type ReceiveTimeout struct { - // ns is length of the timeout in nanoseconds. +type SendReceiveTimeout struct { + // send is length of the send timeout in nanoseconds. + // + // send must be accessed atomically. + send int64 + + // recv is length of the receive timeout in nanoseconds. // - // ns must be accessed atomically. - ns int64 + // recv must be accessed atomically. + recv int64 } // SetRecvTimeout implements Socket.SetRecvTimeout. -func (rt *ReceiveTimeout) SetRecvTimeout(nanoseconds int64) { - atomic.StoreInt64(&rt.ns, nanoseconds) +func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) { + atomic.StoreInt64(&to.recv, nanoseconds) } // RecvTimeout implements Socket.RecvTimeout. -func (rt *ReceiveTimeout) RecvTimeout() int64 { - return atomic.LoadInt64(&rt.ns) +func (to *SendReceiveTimeout) RecvTimeout() int64 { + return atomic.LoadInt64(&to.recv) +} + +// SetSendTimeout implements Socket.SetSendTimeout. +func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) { + atomic.StoreInt64(&to.send, nanoseconds) +} + +// SendTimeout implements Socket.SendTimeout. +func (to *SendReceiveTimeout) SendTimeout() int64 { + return atomic.LoadInt64(&to.send) } // GetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4c9dcbd61..da225eabb 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -45,15 +45,16 @@ import ( // // +stateify savable type SocketOperations struct { - refs.AtomicRefCount - socket.ReceiveTimeout fsutil.PipeSeek `state:"nosave"` fsutil.NotDirReaddir `state:"nosave"` fsutil.NoFsync `state:"nosave"` fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` - ep transport.Endpoint - isPacket bool + refs.AtomicRefCount + socket.SendReceiveTimeout + + ep transport.Endpoint + isPacket bool } // New creates a new unix socket. @@ -367,7 +368,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Endpoint: s.ep, Control: controlMessages.Unix, @@ -404,7 +405,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] break } - if err := t.Block(ch); err != nil { + if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + err = syserror.ErrWouldBlock + } break } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 1165d4566..3049fe6e5 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -612,9 +612,11 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca var haveDeadline bool var deadline ktime.Time - if dl := s.RecvTimeout(); dl != 0 { + if dl := s.RecvTimeout(); dl > 0 { deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT } n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline) @@ -671,10 +673,11 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } if !haveDeadline { - dl := s.RecvTimeout() - if dl != 0 { + if dl := s.RecvTimeout(); dl > 0 { deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT } } @@ -821,10 +824,11 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f var haveDeadline bool var deadline ktime.Time - - if dl := s.RecvTimeout(); dl != 0 { + if dl := s.RecvTimeout(); dl > 0 { deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT } n, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) @@ -1001,8 +1005,17 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, err } + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + // Call the syscall implementation. - n, e := s.SendMsg(t, src, to, int(flags), socket.ControlMessages{Unix: controlMessages}) + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages}) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release() @@ -1052,8 +1065,17 @@ func sendTo(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, fla return 0, err } + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + // Call the syscall implementation. - n, e := s.SendMsg(t, src, to, int(flags), socket.ControlMessages{Unix: control.New(t, s, nil)}) + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index a9edbb950..c65b29112 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -332,6 +332,35 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutSucceeds) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(AllSocketPairTest, SendTimeoutAllowsSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 10 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + char buf[20] = {}; + ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(AllSocketPairTest, SendmsgTimeoutAllowsSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 10 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + char buf[20] = {}; + ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf))); +} + TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -382,6 +411,87 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(AllSocketPairTest, RecvTimeoutUsecTooLarge) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 2000000 // 2 seconds. + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallFailsWithErrno(EDOM)); +} + +TEST_P(AllSocketPairTest, SendTimeoutUsecTooLarge) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 2000000 // 2 seconds. + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallFailsWithErrno(EDOM)); +} + +TEST_P(AllSocketPairTest, RecvTimeoutUsecNeg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = -1 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallFailsWithErrno(EDOM)); +} + +TEST_P(AllSocketPairTest, SendTimeoutUsecNeg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = -1 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallFailsWithErrno(EDOM)); +} + +TEST_P(AllSocketPairTest, RecvTimeoutNegSec) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = -1, .tv_usec = 0 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + char buf[20] = {}; + EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(AllSocketPairTest, RecvmsgTimeoutNegSec) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = -1, .tv_usec = 0 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + struct msghdr msg = {}; + char buf[20] = {}; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EAGAIN)); +} + TEST_P(AllSocketPairTest, RecvWaitAll) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index 3fbbe54d8..6cfadc9da 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -125,5 +125,27 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) { EXPECT_GE(after - before, kDuration); } +TEST_P(BlockingStreamSocketPairTest, SendTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 10 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + char buf[100] = {}; + for (;;) { + int ret; + ASSERT_THAT( + ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); + if (ret == -1) { + break; + } + } +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index 620397746..264b7fe6a 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -225,5 +225,27 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) { EXPECT_EQ(0, memcmp(write_buf.data(), ptr, buffer_size)); } +TEST_P(UnixNonStreamSocketPairTest, SendTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval tv { + .tv_sec = 0, .tv_usec = 10 + }; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + char buf[100] = {}; + for (;;) { + int ret; + ASSERT_THAT( + ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); + if (ret == -1) { + break; + } + } +} + } // namespace testing } // namespace gvisor |