summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2020-08-25 07:15:50 -0700
committerAndrei Vagin <avagin@gmail.com>2020-09-09 17:53:10 -0700
commit086f085660b73e8ead7ca0bfef5835a6aaad8866 (patch)
tree9becb4457f63a694deeacccfcba1b4f3f5fc8dfa /pkg
parent2ddd883a9459ab0f5ef22b81b8efbd1733fda035 (diff)
Fix TCP_LINGER2 behavior to match linux.
We still deviate a bit from linux in how long we will actually wait in FIN-WAIT-2. Linux seems to cap it with TIME_WAIT_LEN and it's not completely obvious as to why it's done that way. For now I think we can ignore that and fix it if it really is an issue. PiperOrigin-RevId: 328324922
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go7
3 files changed, 27 insertions, 13 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 4d0e33696..921464f5d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1409,8 +1409,12 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ var lingerTimeout primitive.Int32
+ if v >= 0 {
+ lingerTimeout = primitive.Int32(time.Duration(v) / time.Second)
+ } else {
+ lingerTimeout = -1
+ }
return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
@@ -1967,7 +1971,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := int32(usermem.ByteOrder.Uint32(optVal))
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
case linux.TCP_DEFER_ACCEPT:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4ba0ea1c0..9c0f4c9f4 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1775,15 +1775,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.TCPLingerTimeoutOption:
e.LockUser()
- if v < 0 {
+
+ switch {
+ case v < 0:
// Same as effectively disabling TCPLinger timeout.
- v = 0
- }
- // Cap it to MaxTCPLingerTimeout.
- stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
- if v > stkTCPLingerTimeout {
- v = stkTCPLingerTimeout
+ v = -1
+ case v == 0:
+ // Same as the stack default.
+ var stackLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err))
+ }
+ v = stackLingerTimeout
+ case v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout):
+ // Cap it to Stack's default TCP_LINGER2 timeout.
+ v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
+ default:
}
+
e.tcpLingerTimeout = time.Duration(v)
e.UnlockUser()
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 55ae09a2f..9650bb06c 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6206,12 +6206,13 @@ func TestTCPLingerTimeout(t *testing.T) {
tcpLingerTimeout time.Duration
want time.Duration
}{
- {"NegativeLingerTimeout", -123123, 0},
- {"ZeroLingerTimeout", 0, 0},
+ {"NegativeLingerTimeout", -123123, -1},
+ // Zero is treated same as the stack's default TCP_LINGER2 timeout.
+ {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
// Values > stack's TCPLingerTimeout are capped to the stack's
// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second},
+ {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {