diff options
-rw-r--r-- | pkg/abi/linux/tcp.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 62 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 10 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_tcp_generic.cc | 136 |
7 files changed, 243 insertions, 16 deletions
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go index 7586ada42..67908deb9 100644 --- a/pkg/abi/linux/tcp.go +++ b/pkg/abi/linux/tcp.go @@ -52,3 +52,9 @@ const ( TCP_ZEROCOPY_RECEIVE = 35 TCP_INQ = 36 ) + +// Socket constants from include/net/tcp.h. +const ( + MAX_TCP_KEEPIDLE = 32767 + MAX_TCP_KEEPINTVL = 32767 +) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index ab5d82183..89580e83a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -636,7 +636,13 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - return int32(0), nil + + var v tcpip.KeepaliveEnabledOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil case linux.SO_LINGER: if outLen < syscall.SizeofLinger { @@ -720,6 +726,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(v), nil + case linux.TCP_KEEPIDLE: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.KeepaliveIdleOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + + case linux.TCP_KEEPINTVL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.KeepaliveIntervalOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -843,6 +873,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + case linux.SO_KEEPALIVE: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { return syserr.ErrInvalidArgument @@ -916,6 +954,28 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + case linux.TCP_KEEPIDLE: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPIDLE { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)))) + + case linux.TCP_KEEPINTVL: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPINTVL { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index e98096d7b..12b1576bd 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -837,6 +837,7 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: return nil + case *tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { @@ -850,6 +851,7 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } *o = qs return nil + case *tcpip.ReceiveQueueSizeOption: e.Lock() if !e.Connected() { @@ -863,6 +865,7 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } *o = qs return nil + case *tcpip.PasscredOption: if e.Passcred() { *o = tcpip.PasscredOption(1) @@ -870,6 +873,7 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.PasscredOption(0) } return nil + case *tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { @@ -883,6 +887,7 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } *o = qs return nil + case *tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { @@ -896,8 +901,14 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } *o = qs return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption } - return tcpip.ErrUnknownProtocolOption } // Shutdown closes the read and/or write end of the endpoint connection to its diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index 10d4d138e..d1b9b136c 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -358,9 +358,15 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 1 } e.rcvMu.Unlock() - } + return nil - return tcpip.ErrUnknownProtocolOption + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 37d4c8f9e..c549132f0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -662,7 +662,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } else { atomic.StoreUint32(&e.delay, 1) } - return nil case tcpip.CorkOption: @@ -674,7 +673,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } else { atomic.StoreUint32(&e.cork, 1) } - return nil case tcpip.ReuseAddressOption: @@ -689,7 +687,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } else { atomic.StoreUint32(&e.slowAck, 0) } - return nil case tcpip.ReceiveBufferSizeOption: @@ -754,7 +751,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Lock() e.sndBufSize = size e.sndBufMu.Unlock() - return nil case tcpip.V6OnlyOption: @@ -772,34 +768,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.v6only = v != 0 + return nil case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil case tcpip.KeepaliveCountOption: e.keepalive.Lock() e.keepalive.count = int(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil + default: + return nil } - - return nil } // readyReceiveSize returns the number of bytes ready to be received. @@ -908,7 +909,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil case *tcpip.KeepaliveEnabledOption: @@ -920,25 +920,29 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { if v { *o = 1 } + return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() + return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() + return nil case *tcpip.KeepaliveCountOption: e.keepalive.Lock() *o = tcpip.KeepaliveCountOption(e.keepalive.count) e.keepalive.Unlock() + return nil + default: + return tcpip.ErrUnknownProtocolOption } - - return tcpip.ErrUnknownProtocolOption } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 57b875680..67e9ca0ac 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -505,15 +505,21 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 1 } e.rcvMu.Unlock() + return nil case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) e.mu.Unlock() return nil - } - return tcpip.ErrUnknownProtocolOption + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // sendUDP sends a UDP segment via the provided network endpoint and under the diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index bb5a83c9a..81508263b 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -296,7 +296,7 @@ TEST_P(TCPSocketPairTest, TCPCorkDefault) { getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CORK, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, 0); + EXPECT_EQ(get, kSockOptOff); } TEST_P(TCPSocketPairTest, SetTCPCork) { @@ -388,5 +388,139 @@ TEST_P(TCPSocketPairTest, SetTCPQuickAck) { EXPECT_EQ(get, kSockOptOn); } +TEST_P(TCPSocketPairTest, SoKeepaliveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(TCPSocketPairTest, SetSoKeepalive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_KEEPALIVE, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(TCPSocketPairTest, TCPKeepidleDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get, + &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 2 * 60 * 60); // 2 hours. +} + +TEST_P(TCPSocketPairTest, TCPKeepintvlDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get, + &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 75); // 75 seconds. +} + +TEST_P(TCPSocketPairTest, SetTCPKeepidleZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &kZero, + sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, + &kZero, sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); +} + +// Copied from include/net/tcp.h. +constexpr int MAX_TCP_KEEPIDLE = 32767; +constexpr int MAX_TCP_KEEPINTVL = 32767; + +TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAboveMax = MAX_TCP_KEEPIDLE + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, + &kAboveMax, sizeof(kAboveMax)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepintvlAboveMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAboveMax = MAX_TCP_KEEPINTVL + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, + &kAboveMax, sizeof(kAboveMax)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepidleToMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, + &MAX_TCP_KEEPIDLE, sizeof(MAX_TCP_KEEPIDLE)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPIDLE, &get, + &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, MAX_TCP_KEEPIDLE); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, + &MAX_TCP_KEEPINTVL, sizeof(MAX_TCP_KEEPINTVL)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPINTVL, &get, + &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, MAX_TCP_KEEPINTVL); +} + } // namespace testing } // namespace gvisor |