summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go15
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go3
-rw-r--r--test/syscalls/linux/accept_bind.cc36
5 files changed, 68 insertions, 19 deletions
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0141e8a96..eff251cec 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -46,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -361,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFile(fd)
@@ -376,6 +379,16 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
+ if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
+ backlog = maxListenBacklog
+ }
+
return 0, nil, s.Listen(t, int(backlog)).ToError()
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 7cc0be892..936614eab 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -46,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -365,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFileVFS2(fd)
@@ -380,6 +383,16 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
+ if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
+ backlog = maxListenBacklog
+ }
+
return 0, nil, s.Listen(t, int(backlog)).ToError()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9438056f9..5001d222e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2474,20 +2474,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
func (e *endpoint) Listen(backlog int) tcpip.Error {
- if uint32(backlog) > MaxListenBacklog {
- // Linux treats incoming backlog as uint with a limit defined by
- // sysctl_somaxconn.
- // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
- //
- // We use the backlog to allocate a channel of that size, hence enforce
- // a hard limit for the backlog.
- backlog = MaxListenBacklog
- } else {
- // Accept one more than the configured listen backlog to keep in parity with
- // Linux. Ref, because of missing equality check here:
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
- backlog++
- }
+ // Accept one more than the configured listen backlog to keep in parity with
+ // Linux. Ref, because of missing equality check here:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ backlog++
err := e.listen(backlog)
if err != nil {
if !err.IgnoreStats() {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 230fa6ebe..fe0d7f10f 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -68,9 +68,6 @@ const (
// DefaultSynRetries is the default value for the number of SYN retransmits
// before a connect is aborted.
DefaultSynRetries = 6
-
- // MaxListenBacklog is the maximum limit of listen backlog supported.
- MaxListenBacklog = 1024
)
const (
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index f65a14fb8..119a1466b 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -67,6 +67,42 @@ TEST_P(AllSocketPairTest, ListenDecreaseBacklog) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ListenBacklogSizes_NoRandomSave) {
+ DisableSave ds;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+
+ std::array<int, 3> backlogs = {-1, 0, 1};
+ for (auto& backlog : backlogs) {
+ ASSERT_THAT(listen(sockets->first_fd(), backlog), SyscallSucceeds());
+
+ int expected_accepts = backlog;
+ if (backlog < 0) {
+ expected_accepts = 1024;
+ }
+ for (int i = 0; i < expected_accepts; i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ // Connect to the listening socket.
+ const FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+ ASSERT_THAT(connect(client.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ const FileDescriptor accepted = ASSERT_NO_ERRNO_AND_VALUE(
+ Accept(sockets->first_fd(), nullptr, nullptr));
+ }
+ }
+}
+
TEST_P(AllSocketPairTest, ListenWithoutBind) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL));