From df37babd576ba4607e2fe69eb2c669aa2954b9cb Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Wed, 18 Nov 2020 14:34:49 -0800 Subject: [netstack] Move SO_REUSEPORT and SO_REUSEADDR option to SocketOptions. This changes also introduces: - `SocketOptionsHandler` interface which can be implemented by endpoints to handle endpoint specific behavior on SetSockOpt. This is analogous to what Linux does. - `DefaultSocketOptionsHandler` which is a default implementation of the above. This is embedded in all endpoints so that we don't have to uselessly implement empty functions. Endpoints with specific behavior can override the embedded method by manually defining its own implementation. PiperOrigin-RevId: 343158301 --- pkg/sentry/socket/netstack/netstack.go | 23 +++++++----------- pkg/sentry/socket/unix/transport/connectioned.go | 27 +++++++++++----------- pkg/sentry/socket/unix/transport/connectionless.go | 1 + pkg/sentry/socket/unix/transport/unix.go | 8 +++---- 4 files changed, 25 insertions(+), 34 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d48b92c66..bf6d8c5dc 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1112,25 +1112,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1869,7 +1860,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1877,7 +1869,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 6d9e502bd..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,28 +118,24 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.InitRefs() @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 1406971bc..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0324dcd93..8482d1603 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -738,6 +738,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -756,6 +757,7 @@ type baseEndpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -856,11 +858,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } + log.Warningf("Unsupported socket option: %d", opt) return nil } -- cgit v1.2.3