diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 7 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_tcp_generic.cc | 31 |
4 files changed, 52 insertions, 19 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 4d0e33696..921464f5d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1409,8 +1409,12 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + var lingerTimeout primitive.Int32 + if v >= 0 { + lingerTimeout = primitive.Int32(time.Duration(v) / time.Second) + } else { + lingerTimeout = -1 + } return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: @@ -1967,7 +1971,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) + v := int32(usermem.ByteOrder.Uint32(optVal)) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) case linux.TCP_DEFER_ACCEPT: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4ba0ea1c0..9c0f4c9f4 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1775,15 +1775,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { case tcpip.TCPLingerTimeoutOption: e.LockUser() - if v < 0 { + + switch { + case v < 0: // Same as effectively disabling TCPLinger timeout. - v = 0 - } - // Cap it to MaxTCPLingerTimeout. - stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) - if v > stkTCPLingerTimeout { - v = stkTCPLingerTimeout + v = -1 + case v == 0: + // Same as the stack default. + var stackLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err)) + } + v = stackLingerTimeout + case v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout): + // Cap it to Stack's default TCP_LINGER2 timeout. + v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + default: } + e.tcpLingerTimeout = time.Duration(v) e.UnlockUser() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 55ae09a2f..9650bb06c 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6206,12 +6206,13 @@ func TestTCPLingerTimeout(t *testing.T) { tcpLingerTimeout time.Duration want time.Duration }{ - {"NegativeLingerTimeout", -123123, 0}, - {"ZeroLingerTimeout", 0, 0}, + {"NegativeLingerTimeout", -123123, -1}, + // Zero is treated same as the stack's default TCP_LINGER2 timeout. + {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, // Values > stack's TCPLingerTimeout are capped to the stack's // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second}, + {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 53c076787..04356b780 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -819,18 +819,37 @@ TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { EXPECT_EQ(get, kDefaultTCPLingerTimeout); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutLessThanZero) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, - sizeof(kZero)), - SyscallSucceedsWithValue(0)); - constexpr int kNegative = -1234; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kNegative, sizeof(kNegative)), SyscallSucceedsWithValue(0)); + int get = INT_MAX; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, -1); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); } TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) { |