diff options
author | Ayush Ranjan <ayushranjan@google.com> | 2020-11-18 14:34:49 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-11-18 14:36:41 -0800 |
commit | df37babd576ba4607e2fe69eb2c669aa2954b9cb (patch) | |
tree | 56ecfd74ce4a6234871e62ce7e726f47e341c81a | |
parent | c85bba03852412e297740bcff9fdc076f0feb58e (diff) |
[netstack] Move SO_REUSEPORT and SO_REUSEADDR option to SocketOptions.
This changes also introduces:
- `SocketOptionsHandler` interface which can be implemented by endpoints to
handle endpoint specific behavior on SetSockOpt. This is analogous to what
Linux does.
- `DefaultSocketOptionsHandler` which is a default implementation of the above.
This is embedded in all endpoints so that we don't have to uselessly
implement empty functions. Endpoints with specific behavior can override the
embedded method by manually defining its own implementation.
PiperOrigin-RevId: 343158301
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 23 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 40 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 3 |
16 files changed, 150 insertions, 124 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d48b92c66..bf6d8c5dc 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1112,25 +1112,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1869,7 +1860,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1877,7 +1869,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 6d9e502bd..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,28 +118,24 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.InitRefs() @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 1406971bc..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0324dcd93..8482d1603 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -738,6 +738,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -756,6 +757,7 @@ type baseEndpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -856,11 +858,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } + log.Warningf("Unsupported socket option: %d", opt) return nil } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index cc3d59d9d..99c3e9c45 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -18,11 +18,36 @@ import ( "sync/atomic" ) +// SocketOptionsHandler holds methods that help define endpoint specific +// behavior for socket options. These must be implemented by endpoints to get +// notified when socket level options are set. +type SocketOptionsHandler interface { + // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. + OnReuseAddressSet(v bool) + + // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint. + OnReusePortSet(v bool) +} + +// DefaultSocketOptionsHandler is an embeddable type that implements no-op +// implementations for SocketOptionsHandler methods. +type DefaultSocketOptionsHandler struct{} + +var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil) + +// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet. +func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {} + +// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet. +func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} + // SocketOptions contains all the variables which store values for SOL_SOCKET // level options. // // +stateify savable type SocketOptions struct { + handler SocketOptionsHandler + // These fields are accessed and modified using atomic operations. // broadcastEnabled determines whether datagram sockets are allowed to send @@ -36,6 +61,20 @@ type SocketOptions struct { // noChecksumEnabled determines whether UDP checksum is disabled while // transmitting for this socket. noChecksumEnabled uint32 + + // reuseAddressEnabled determines whether Bind() should allow reuse of local + // address. + reuseAddressEnabled uint32 + + // reusePortEnabled determines whether to permit multiple sockets to be bound + // to an identical socket address. + reusePortEnabled uint32 +} + +// InitHandler initializes the handler. This must be called before using the +// socket options utility. +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { + so.handler = handler } func storeAtomicBool(addr *uint32, v bool) { @@ -75,3 +114,25 @@ func (so *SocketOptions) GetNoChecksum() bool { func (so *SocketOptions) SetNoChecksum(v bool) { storeAtomicBool(&so.noChecksumEnabled, v) } + +// GetReuseAddress gets value for SO_REUSEADDR option. +func (so *SocketOptions) GetReuseAddress() bool { + return atomic.LoadUint32(&so.reuseAddressEnabled) != 0 +} + +// SetReuseAddress sets value for SO_REUSEADDR option. +func (so *SocketOptions) SetReuseAddress(v bool) { + storeAtomicBool(&so.reuseAddressEnabled, v) + so.handler.OnReuseAddressSet(v) +} + +// GetReusePort gets value for SO_REUSEPORT option. +func (so *SocketOptions) GetReusePort() bool { + return atomic.LoadUint32(&so.reusePortEnabled) != 0 +} + +// SetReusePort sets value for SO_REUSEPORT option. +func (so *SocketOptions) SetReusePort(v bool) { + storeAtomicBool(&so.reusePortEnabled, v) + so.handler.OnReusePortSet(v) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 41a8e5ad0..2cdb5ca79 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { - t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) - } + ep.SocketOptions().SetReusePort(endpoint.reuse) bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5b9043d85..fbac66993 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,6 +38,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler proto *fakeTransportProtocol peerAddr tcpip.Address @@ -45,7 +46,7 @@ type fakeTransportEndpoint struct { uniqueID uint64 // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint + acceptQueue []*fakeTransportEndpoint // ops is used to set and get socket options. ops tcpip.SocketOptions @@ -65,7 +66,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep.ops.InitHandler(ep) + return ep } func (f *fakeTransportEndpoint) Abort() { @@ -189,7 +192,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai if len(f.acceptQueue) == 0 { return nil, nil, nil } - a := &f.acceptQueue[0] + a := f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] return a, nil, nil } @@ -206,7 +209,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { ); err != nil { return err } - f.acceptQueue = []fakeTransportEndpoint{} + f.acceptQueue = []*fakeTransportEndpoint{} return nil } @@ -232,7 +235,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * } route.ResolveWith(pkt.SourceLinkAddress()) - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -240,7 +243,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * proto: f.proto, peerAddr: route.RemoteAddress, route: route, - }) + } + ep.ops.InitHandler(ep) + f.acceptQueue = append(f.acceptQueue, ep) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 09361360f..7ae36fde7 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -738,14 +738,6 @@ const ( // interface index and address. ReceiveIPPacketInfoOption - // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to - // specify whether Bind() should allow reuse of local address. - ReuseAddressOption - - // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit - // multiple sockets to be bound to an identical socket address. - ReusePortOption - // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify // whether an IPv6 socket is to be restricted to sending and receiving // IPv6 packets only. diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 9d30329f5..8be791a00 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -510,10 +510,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) - } - + ep.SocketOptions().SetReuseAddress(true) ep.SocketOptions().SetBroadcast(true) bindAddr := tcpip.FullAddress{Port: localPort} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index fe6514bcd..39560a9fa 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -49,6 +49,7 @@ const ( // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. @@ -85,7 +86,7 @@ type endpoint struct { } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return &endpoint{ + ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -96,7 +97,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), - }, nil + } + ep.ops.InitHandler(ep) + return ep, nil } // UniqueID implements stack.TransportEndpoint.UniqueID. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3bff3755a..35d1be792 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -60,6 +60,8 @@ type packet struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` @@ -107,6 +109,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } + ep.ops.InitHandler(ep) // Override with stack defaults. var ss stack.SendBufferSizeOption diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 0a1e1fbb3..e64392f7b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -58,6 +58,8 @@ type rawPacket struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` @@ -116,6 +118,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt associated: associated, hdrIncluded: !associated, } + e.ops.InitHandler(e) // Override with stack defaults. var ss stack.SendBufferSizeOption diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 36b915510..f893324c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -362,6 +362,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // +stateify savable type endpoint struct { EndpointInfo + tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the // a given tcp processor goroutine. @@ -884,6 +885,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } + e.ops.InitHandler(e) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -1627,6 +1629,20 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo return false, false } +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() +} + +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() +} + // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { @@ -1666,16 +1682,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } atomic.StoreUint32(&e.slowAck, o) - case tcpip.ReuseAddressOption: - e.LockUser() - e.portFlags.TupleOnly = v - e.UnlockUser() - - case tcpip.ReusePortOption: - e.LockUser() - e.portFlags.LoadBalanced = v - e.UnlockUser() - case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1995,20 +2001,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { v := atomic.LoadUint32(&e.slowAck) == 0 return v, nil - case tcpip.ReuseAddressOption: - e.LockUser() - v := e.portFlags.TupleOnly - e.UnlockUser() - - return v, nil - - case tcpip.ReusePortOption: - e.LockUser() - v := e.portFlags.LoadBalanced - e.UnlockUser() - - return v, nil - case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index c366a4cbc..9fa3aa740 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4191,9 +4191,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4203,9 +4201,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4216,9 +4212,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4231,9 +4225,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4244,9 +4236,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4259,9 +4249,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5aa16bf35..e57833644 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -77,6 +77,7 @@ func (s EndpointState) String() string { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. @@ -194,6 +195,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue state: StateInitial, uniqueID: s.UniqueID(), } + e.ops.InitHandler(e) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -574,6 +576,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() +} + +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() +} + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { @@ -602,16 +618,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.receiveIPPacketInfo = v e.mu.Unlock() - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.portFlags.MostRecent = v - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.portFlags.LoadBalanced = v - e.mu.Unlock() - case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -875,20 +881,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil - case tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.portFlags.MostRecent - e.mu.RUnlock() - - return v, nil - - case tcpip.ReusePortOption: - e.mu.RLock() - v := e.portFlags.LoadBalanced - e.mu.RUnlock() - - return v, nil - case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index d17192c36..81e00b1dc 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -819,7 +819,8 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) { } TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) { - int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK}; + int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK, SO_REUSEADDR, + SO_REUSEPORT}; for (int sock_opt : sock_opts) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int enable = -1; |