summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go27
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go85
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc110
4 files changed, 187 insertions, 39 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 8cb5c823f..0f2cd05fc 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -429,6 +429,11 @@ func (i *ioSequencePayload) Size() int {
return int(i.src.NumBytes())
}
+// DropFirst drops the first n bytes from underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+ i.src = i.src.DropFirst(int(n))
+}
+
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
@@ -2026,28 +2031,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
addr = &addrBuf
}
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
opts := tcpip.WriteOptions{
To: addr,
More: flags&linux.MSG_MORE != 0,
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ v := &ioSequencePayload{t, src}
+ n, resCh, err := s.Endpoint.Write(v, opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err)
}
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ if err == nil && (n >= uintptr(v.Size()) || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2061,18 +2060,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
- v.TrimFront(int(n))
+ v.DropFirst(int(n))
total := n
for {
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
- v.TrimFront(int(n))
+ n, _, err = s.Endpoint.Write(v, opts)
+ v.DropFirst(int(n))
total += n
if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e67169111..7c42a830a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -878,60 +878,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
return v, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
-
- e.mu.RLock()
- defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, nil, e.hardError
+ return 0, e.hardError
default:
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
}
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
- }
-
- e.sndBufMu.Lock()
-
// Check if the connection has already been closed for sends.
if e.sndClosed {
- e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
- // Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrWouldBlock
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
}
+ // Copy in memory without holding sndBufMu so that worker goroutine can
+ // make progress independent of this operation.
v, perr := p.Get(avail)
if perr != nil {
- e.sndBufMu.Unlock()
return 0, nil, perr
}
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a
+ // write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due to a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
// Add data to the send queue.
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
e.sndBufUsed += l
e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
if e.workMu.TryLock() {
// Do the work inline.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 16666e772..d28ce4ba1 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1252,10 +1252,14 @@ cc_binary(
srcs = ["partial_bad_buffer.cc"],
linkstatic = 1,
deps = [
+ "//test/syscalls/linux:socket_test_util",
+ "//test/util:file_descriptor",
"//test/util:fs_util",
+ "//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 83b1ad4e4..33822ee57 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -14,13 +14,20 @@
#include <errno.h>
#include <fcntl.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
#include <sys/mman.h>
+#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/uio.h>
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
EXPECT_STREQ(buf, kMessage);
}
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// SendMsgTCP verifies that calling sendmsg with a bad address returns an
+// EFAULT. It also verifies that passing a buffer which is made up of 2
+// pages one valid and one guard page succeeds as long as the write is
+// for exactly the size of 1 page.
+TEST_F(PartialBadBufferTest, SendMsgTCP) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // TODO(gvisor.dev/issue/674): Update this once Netstack matches linux
+ // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF.
+ //
+ // Set SO_SNDBUF for socket to exactly kPageSize+1.
+ //
+ // gVisor does not double the value passed in SO_SNDBUF like linux does so we
+ // just increase it by 1 byte here for gVisor so that we can test writing 1
+ // byte past the valid page and check that it triggers an EFAULT
+ // correctly. Otherwise in gVisor the sendmsg call will just return with no
+ // error with kPageSize bytes written successfully.
+ const uint32_t buf_size = kPageSize + 1;
+ ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceedsWithValue(0));
+
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ iov.iov_base = bad_buffer_;
+ iov.iov_len = kPageSize;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallFailsWithErrno(EFAULT));
+
+ // Now assert that writing kPageSize from addr_ succeeds.
+ iov.iov_base = addr_;
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallSucceedsWithValue(kPageSize));
+ // Read all the data out so that we drain the socket SND_BUF on the sender.
+ std::vector<char> buffer(kPageSize);
+ ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Sleep for a shortwhile to ensure that we have time to process the
+ // ACKs. This is not strictly required unless running under gotsan which is a
+ // lot slower and can result in the next write to write only 1 byte instead of
+ // our intended kPageSize + 1.
+ absl::SleepFor(absl::Milliseconds(50));
+
+ // Now assert that writing > kPageSize results in EFAULT.
+ iov.iov_len = kPageSize + 1;
+ ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+ SyscallFailsWithErrno(EFAULT));
+}
+
} // namespace
} // namespace testing