summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2018-12-14 16:12:51 -0800
committerShentubot <shentubot@google.com>2018-12-14 16:15:06 -0800
commite1dcf92ec5cf7d9bf58fb322f46f6ae2d98699d2 (patch)
tree61ed22a594dd96bb994d748e83358c5a51212ee5
parented930354ef46df9b6feece36e59ee644a7cdfa7f (diff)
Implement SO_SNDTIMEO
PiperOrigin-RevId: 225620490 Change-Id: Ia726107b3f58093a5f881634f90b071b32d2c269
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/socket_test.go3
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go35
-rw-r--r--pkg/sentry/socket/hostinet/socket.go9
-rw-r--r--pkg/sentry/socket/netlink/socket.go4
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go41
-rw-r--r--pkg/sentry/socket/socket.go49
-rw-r--r--pkg/sentry/socket/unix/unix.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go36
-rw-r--r--test/syscalls/linux/socket_generic.cc110
-rw-r--r--test/syscalls/linux/socket_stream_blocking.cc22
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc22
12 files changed, 307 insertions, 41 deletions
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 89d7b2fe7..73d9cc71a 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -71,6 +71,7 @@ go_test(
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index 17bf397ef..6ddf63a6a 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/fd"
"gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
@@ -142,7 +143,7 @@ func TestSocketSendMsgLen0(t *testing.T) {
defer sfile.DecRef()
s := sfile.FileOperations.(socket.Socket)
- n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, socket.ControlMessages{})
+ n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{})
if n != 0 {
t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n)
}
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 19af7bc45..ab5d82183 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -30,6 +30,7 @@ import (
"strings"
"sync"
"syscall"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
@@ -137,12 +138,12 @@ type commonEndpoint interface {
//
// +stateify savable
type SocketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
*waiter.Queue
family int
@@ -643,7 +644,16 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
}
return syscall.Linger{}, nil
+ case linux.SO_SNDTIMEO:
+ // TODO: Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+
case linux.SO_RCVTIMEO:
+ // TODO: Linux allows shorter lengths for partial results.
if outLen < linux.SizeOfTimeval {
return nil, syserr.ErrInvalidArgument
}
@@ -833,6 +843,19 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ case linux.SO_SNDTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+
case linux.SO_RCVTIMEO:
if len(optVal) < linux.SizeOfTimeval {
return syserr.ErrInvalidArgument
@@ -840,6 +863,9 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
var v linux.Timeval
binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
s.SetRecvTimeout(v.ToNsecCapped())
return nil
@@ -1365,7 +1391,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Reject Unix control messages.
if !controlMessages.Unix.Empty() {
return 0, syserr.ErrInvalidArgument
@@ -1431,7 +1457,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(total), nil
}
- if err := t.Block(ch); err != nil {
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
// handleIOError will consume errors from t.Block if needed.
return int(total), syserr.FromError(err)
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index e4e950fbb..34281cac0 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -46,12 +46,12 @@ const (
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
fd int // must be O_NONBLOCK
queue waiter.Queue
@@ -418,7 +418,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -468,7 +468,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
}
if ch != nil {
- if err = t.Block(ch); err != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
break
}
} else {
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index c4798839e..0a7d4772c 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -65,12 +65,12 @@ var netlinkSocketDevice = device.NewAnonDevice()
//
// +stateify savable
type Socket struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
// ports provides netlink port allocation.
ports *port.Manager
@@ -593,7 +593,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte,
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
return s.sendMsg(t, src, to, flags, controlMessages)
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 90844f10f..257bc2d71 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -17,6 +17,7 @@ package rpcinet
import (
"sync/atomic"
"syscall"
+ "time"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
@@ -44,12 +45,12 @@ import (
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
+ socket.SendReceiveTimeout
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
@@ -379,7 +380,8 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements socket.Socket.GetSockOpt.
func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO is special because blocking is performed within the sentry.
+ // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
+ // within the sentry.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if outLen < linux.SizeOfTimeval {
return nil, syserr.ErrInvalidArgument
@@ -387,6 +389,13 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
return linux.NsecToTimeval(s.RecvTimeout()), nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
@@ -403,8 +412,9 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLe
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Because blocking actually happens within the sentry we need to inspect
- // this socket option to determine if it's a SO_RCVTIMEO, and if so, we will
- // save it and use it as the deadline for recv(2) related syscalls.
+ // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
+ // and if so, we will save it and use it as the deadline for recv(2)
+ // or send(2) related syscalls.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
if len(opt) < linux.SizeOfTimeval {
return syserr.ErrInvalidArgument
@@ -412,9 +422,25 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
var v linux.Timeval
binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
s.SetRecvTimeout(v.ToNsecCapped())
return nil
}
+ if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
+ if len(opt) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+ }
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
@@ -720,7 +746,7 @@ func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
@@ -787,7 +813,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(totalWritten), nil
}
- if err := t.Block(ch); err != nil {
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(totalWritten), syserr.ErrTryAgain
+ }
return int(totalWritten), syserr.FromError(err)
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index f73127ea6..9d4aaeb9d 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -94,15 +94,23 @@ type Socket interface {
// ownership of the ControlMessage on error.
//
// If n > 0, err will either be nil or an error from t.Block.
- SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error)
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error)
// SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
- // no timeout.
+ // no timeout, and negative means DONTWAIT.
SetRecvTimeout(nanoseconds int64)
// RecvTimeout gets the current timeout (in ns) for recv operations. Zero
- // means no timeout.
+ // means no timeout, and negative means DONTWAIT.
RecvTimeout() int64
+
+ // SetSendTimeout sets the timeout (in ns) for send operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetSendTimeout(nanoseconds int64)
+
+ // SendTimeout gets the current timeout (in ns) for send operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ SendTimeout() int64
}
// Provider is the interface implemented by providers of sockets for specific
@@ -192,30 +200,45 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(inode, fmt.Sprintf("socket:[%d]", ino))
}
-// ReceiveTimeout stores a timeout for receive calls.
+// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
// interface.
//
-// Care must be taken when copying ReceiveTimeout as it contains atomic
+// Care must be taken when copying SendReceiveTimeout as it contains atomic
// variables.
//
// +stateify savable
-type ReceiveTimeout struct {
- // ns is length of the timeout in nanoseconds.
+type SendReceiveTimeout struct {
+ // send is length of the send timeout in nanoseconds.
+ //
+ // send must be accessed atomically.
+ send int64
+
+ // recv is length of the receive timeout in nanoseconds.
//
- // ns must be accessed atomically.
- ns int64
+ // recv must be accessed atomically.
+ recv int64
}
// SetRecvTimeout implements Socket.SetRecvTimeout.
-func (rt *ReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
- atomic.StoreInt64(&rt.ns, nanoseconds)
+func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.recv, nanoseconds)
}
// RecvTimeout implements Socket.RecvTimeout.
-func (rt *ReceiveTimeout) RecvTimeout() int64 {
- return atomic.LoadInt64(&rt.ns)
+func (to *SendReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&to.recv)
+}
+
+// SetSendTimeout implements Socket.SetSendTimeout.
+func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.send, nanoseconds)
+}
+
+// SendTimeout implements Socket.SendTimeout.
+func (to *SendReceiveTimeout) SendTimeout() int64 {
+ return atomic.LoadInt64(&to.send)
}
// GetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4c9dcbd61..da225eabb 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -45,15 +45,16 @@ import (
//
// +stateify savable
type SocketOperations struct {
- refs.AtomicRefCount
- socket.ReceiveTimeout
fsutil.PipeSeek `state:"nosave"`
fsutil.NotDirReaddir `state:"nosave"`
fsutil.NoFsync `state:"nosave"`
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
- ep transport.Endpoint
- isPacket bool
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ isPacket bool
}
// New creates a new unix socket.
@@ -367,7 +368,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Endpoint: s.ep,
Control: controlMessages.Unix,
@@ -404,7 +405,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
break
}
- if err := t.Block(ch); err != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
break
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 1165d4566..3049fe6e5 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -612,9 +612,11 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
var haveDeadline bool
var deadline ktime.Time
- if dl := s.RecvTimeout(); dl != 0 {
+ if dl := s.RecvTimeout(); dl > 0 {
deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
}
n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
@@ -671,10 +673,11 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
if !haveDeadline {
- dl := s.RecvTimeout()
- if dl != 0 {
+ if dl := s.RecvTimeout(); dl > 0 {
deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
}
}
@@ -821,10 +824,11 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f
var haveDeadline bool
var deadline ktime.Time
-
- if dl := s.RecvTimeout(); dl != 0 {
+ if dl := s.RecvTimeout(); dl > 0 {
deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
}
n, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
@@ -1001,8 +1005,17 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, err
}
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
// Call the syscall implementation.
- n, e := s.SendMsg(t, src, to, int(flags), socket.ControlMessages{Unix: controlMessages})
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages})
err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
controlMessages.Release()
@@ -1052,8 +1065,17 @@ func sendTo(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, fla
return 0, err
}
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
// Call the syscall implementation.
- n, e := s.SendMsg(t, src, to, int(flags), socket.ControlMessages{Unix: control.New(t, s, nil)})
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
}
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index a9edbb950..c65b29112 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -332,6 +332,35 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutAllowsSend) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(AllSocketPairTest, SendmsgTimeoutAllowsSend) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
+}
+
TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -382,6 +411,87 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, RecvTimeoutUsecTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 2000000 // 2 seconds.
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutUsecTooLarge) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 2000000 // 2 seconds.
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutUsecNeg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = -1
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, SendTimeoutUsecNeg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = -1
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallFailsWithErrno(EDOM));
+}
+
+TEST_P(AllSocketPairTest, RecvTimeoutNegSec) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = -1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[20] = {};
+ EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(AllSocketPairTest, RecvmsgTimeoutNegSec) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = -1, .tv_usec = 0
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ struct msghdr msg = {};
+ char buf[20] = {};
+ struct iovec iov;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ EXPECT_THAT(RetryEINTR(recvmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
TEST_P(AllSocketPairTest, RecvWaitAll) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc
index 3fbbe54d8..6cfadc9da 100644
--- a/test/syscalls/linux/socket_stream_blocking.cc
+++ b/test/syscalls/linux/socket_stream_blocking.cc
@@ -125,5 +125,27 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
EXPECT_GE(after - before, kDuration);
}
+TEST_P(BlockingStreamSocketPairTest, SendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[100] = {};
+ for (;;) {
+ int ret;
+ ASSERT_THAT(
+ ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN)));
+ if (ret == -1) {
+ break;
+ }
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index 620397746..264b7fe6a 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -225,5 +225,27 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) {
EXPECT_EQ(0, memcmp(write_buf.data(), ptr, buffer_size));
}
+TEST_P(UnixNonStreamSocketPairTest, SendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ char buf[100] = {};
+ for (;;) {
+ int ret;
+ ASSERT_THAT(
+ ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN)));
+ if (ret == -1) {
+ break;
+ }
+ }
+}
+
} // namespace testing
} // namespace gvisor