diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 20 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 6 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 4 |
12 files changed, 66 insertions, 84 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index bf6d8c5dc..e8a0103bf 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1162,13 +1162,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive())) + return &v, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1231,12 +1226,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn())) + return &v, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1918,7 +1909,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) + ep.SocketOptions().SetKeepAlive(v != 0) + return nil case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 8482d1603..4abea90cc 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -873,14 +873,8 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - default: - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption } func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 99c3e9c45..1b1188ee5 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -19,14 +19,23 @@ import ( ) // SocketOptionsHandler holds methods that help define endpoint specific -// behavior for socket options. These must be implemented by endpoints to get -// notified when socket level options are set. +// behavior for socket options. +// These must be implemented by endpoints to: +// - Get notified when socket level options are set. +// - Provide endpoint specific socket options. type SocketOptionsHandler interface { // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. OnReuseAddressSet(v bool) // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint. OnReusePortSet(v bool) + + // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint. + OnKeepAliveSet(v bool) + + // IsListening is invoked to fetch SO_ACCEPTCONN option value for an + // endpoint. It is used to indicate if the socket is a listening socket. + IsListening() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -41,6 +50,12 @@ func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {} // OnReusePortSet implements SocketOptionsHandler.OnReusePortSet. func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} +// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet. +func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {} + +// IsListening implements SocketOptionsHandler.IsListening. +func (*DefaultSocketOptionsHandler) IsListening() bool { return false } + // SocketOptions contains all the variables which store values for SOL_SOCKET // level options. // @@ -69,6 +84,10 @@ type SocketOptions struct { // reusePortEnabled determines whether to permit multiple sockets to be bound // to an identical socket address. reusePortEnabled uint32 + + // keepAliveEnabled determines whether TCP keepalive is enabled for this + // socket. + keepAliveEnabled uint32 } // InitHandler initializes the handler. This must be called before using the @@ -136,3 +155,20 @@ func (so *SocketOptions) SetReusePort(v bool) { storeAtomicBool(&so.reusePortEnabled, v) so.handler.OnReusePortSet(v) } + +// GetKeepAlive gets value for SO_KEEPALIVE option. +func (so *SocketOptions) GetKeepAlive() bool { + return atomic.LoadUint32(&so.keepAliveEnabled) != 0 +} + +// SetKeepAlive sets value for SO_KEEPALIVE option. +func (so *SocketOptions) SetKeepAlive(v bool) { + storeAtomicBool(&so.keepAliveEnabled, v) + so.handler.OnKeepAliveSet(v) +} + +// GetAcceptConn gets value for SO_ACCEPTCONN option. +func (so *SocketOptions) GetAcceptConn() bool { + // This option is completely endpoint dependent and unsettable. + return so.handler.IsListening() +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 7ae36fde7..4c367b6a6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -708,10 +708,6 @@ const ( // TCP, it determines if the Nagle algorithm is on or off. DelayOption - // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to - // specify whether TCP keepalive is enabled for this socket. - KeepaliveEnabledOption - // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to // specify whether multicast packets sent over a non-loopback interface // will be looped back. @@ -747,10 +743,6 @@ const ( // endpoint that all packets being written have an IP header and the // endpoint should not attach an IP header. IPHdrIncludedOption - - // AcceptConnOption is used by GetSockOptBool to indicate if the - // socket is a listening socket. - AcceptConnOption ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 39560a9fa..59ec54ca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -383,13 +383,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 35d1be792..e2c7a0d62 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -395,12 +395,7 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.AcceptConnOption: - return false, nil - default: - return false, tcpip.ErrNotSupported - } + return false, tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e64392f7b..b0b53b181 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -607,9 +607,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - case tcpip.IPHdrIncludedOption: e.mu.Lock() v := e.hdrIncluded diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6661e8915..88a632019 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1284,7 +1284,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() - if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } @@ -1321,7 +1321,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f893324c2..9c9f1ab87 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -851,7 +851,6 @@ func (e *endpoint) recentTimestamp() uint32 { // +stateify savable type keepalive struct { sync.Mutex `state:"nosave"` - enabled bool idle time.Duration interval time.Duration count int @@ -1643,6 +1642,11 @@ func (e *endpoint) OnReusePortSet(v bool) { e.UnlockUser() } +// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. +func (e *endpoint) OnKeepAliveSet(v bool) { + e.notifyProtocolGoroutine(notifyKeepaliveChanged) +} + // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { @@ -1669,12 +1673,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.sndWaker.Assert() } - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.QuickAckOption: o := uint32(1) if v { @@ -1980,6 +1978,13 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// IsListening implements tcpip.SocketOptionsHandler.IsListening. +func (e *endpoint) IsListening() bool { + e.LockUser() + defer e.UnlockUser() + return e.EndpointState() == StateListen +} + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { @@ -1990,13 +1995,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.DelayOption: return atomic.LoadUint32(&e.delay) != 0, nil - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - return v, nil - case tcpip.QuickAckOption: v := atomic.LoadUint32(&e.slowAck) == 0 return v, nil @@ -2016,12 +2014,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.MulticastLoopOption: return true, nil - case tcpip.AcceptConnOption: - e.LockUser() - defer e.UnlockUser() - - return e.EndpointState() == StateListen, nil - default: return false, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9fa3aa740..7124a715d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4984,9 +4984,7 @@ func TestKeepalive(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -7236,9 +7234,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index e57833644..a7a405dcb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -849,9 +849,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: - return false, nil - case tcpip.MulticastLoopOption: e.mu.RLock() v := e.multicastLoop @@ -893,9 +890,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil - case tcpip.AcceptConnOption: - return false, nil - default: return false, tcpip.ErrUnknownProtocolOption } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 81e00b1dc..70cc86b16 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -819,8 +819,8 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) { } TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) { - int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK, SO_REUSEADDR, - SO_REUSEPORT}; + int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK, + SO_REUSEADDR, SO_REUSEPORT, SO_KEEPALIVE}; for (int sock_opt : sock_opts) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int enable = -1; |