summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go7
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc31
4 files changed, 52 insertions, 19 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 4d0e33696..921464f5d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1409,8 +1409,12 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ var lingerTimeout primitive.Int32
+ if v >= 0 {
+ lingerTimeout = primitive.Int32(time.Duration(v) / time.Second)
+ } else {
+ lingerTimeout = -1
+ }
return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
@@ -1967,7 +1971,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := int32(usermem.ByteOrder.Uint32(optVal))
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
case linux.TCP_DEFER_ACCEPT:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4ba0ea1c0..9c0f4c9f4 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1775,15 +1775,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.TCPLingerTimeoutOption:
e.LockUser()
- if v < 0 {
+
+ switch {
+ case v < 0:
// Same as effectively disabling TCPLinger timeout.
- v = 0
- }
- // Cap it to MaxTCPLingerTimeout.
- stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
- if v > stkTCPLingerTimeout {
- v = stkTCPLingerTimeout
+ v = -1
+ case v == 0:
+ // Same as the stack default.
+ var stackLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err))
+ }
+ v = stackLingerTimeout
+ case v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout):
+ // Cap it to Stack's default TCP_LINGER2 timeout.
+ v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
+ default:
}
+
e.tcpLingerTimeout = time.Duration(v)
e.UnlockUser()
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 55ae09a2f..9650bb06c 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6206,12 +6206,13 @@ func TestTCPLingerTimeout(t *testing.T) {
tcpLingerTimeout time.Duration
want time.Duration
}{
- {"NegativeLingerTimeout", -123123, 0},
- {"ZeroLingerTimeout", 0, 0},
+ {"NegativeLingerTimeout", -123123, -1},
+ // Zero is treated same as the stack's default TCP_LINGER2 timeout.
+ {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
// Values > stack's TCPLingerTimeout are capped to the stack's
// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second},
+ {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 53c076787..04356b780 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -819,18 +819,37 @@ TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) {
EXPECT_EQ(get, kDefaultTCPLingerTimeout);
}
-TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) {
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutLessThanZero) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- constexpr int kZero = 0;
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero,
- sizeof(kZero)),
- SyscallSucceedsWithValue(0));
-
constexpr int kNegative = -1234;
EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
&kNegative, sizeof(kNegative)),
SyscallSucceedsWithValue(0));
+ int get = INT_MAX;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, -1);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero,
+ sizeof(kZero)),
+ SyscallSucceedsWithValue(0));
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_THAT(get,
+ AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout)));
}
TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) {