summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go23
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go27
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go8
-rw-r--r--pkg/tcpip/socketops.go61
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go17
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go5
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go3
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go3
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go40
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go24
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go40
-rw-r--r--test/syscalls/linux/socket_generic.cc3
16 files changed, 150 insertions, 124 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index d48b92c66..bf6d8c5dc 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1112,25 +1112,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress()))
+ return &v, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort()))
+ return &v, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1869,7 +1860,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+ ep.SocketOptions().SetReuseAddress(v != 0)
+ return nil
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1877,7 +1869,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+ ep.SocketOptions().SetReusePort(v != 0)
+ return nil
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 6d9e502bd..9f7aca305 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -118,28 +118,24 @@ var (
// NewConnectioned creates a new unbound connectionedEndpoint.
func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
- return &connectionedEndpoint{
+ return newConnectioned(ctx, stype, uid)
+}
+
+func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
- a := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
- b := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
+ a := newConnectioned(ctx, stype, uid)
+ b := newConnectioned(ctx, stype, uid)
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
q1.InitRefs()
@@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
- return &connectionedEndpoint{
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// ID implements ConnectingEndpoint.ID.
@@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
+ ne.ops.InitHandler(ne)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
readQueue.InitRefs()
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 1406971bc..0813ad87d 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 0324dcd93..8482d1603 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -738,6 +738,7 @@ func (e *connectedEndpoint) CloseUnread() {
// +stateify savable
type baseEndpoint struct {
*waiter.Queue
+ tcpip.DefaultSocketOptionsHandler
// Mutex protects the below fields.
sync.Mutex `state:"nosave"`
@@ -756,6 +757,7 @@ type baseEndpoint struct {
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
+ // ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -856,11 +858,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.ReuseAddressOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
+ log.Warningf("Unsupported socket option: %d", opt)
return nil
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index cc3d59d9d..99c3e9c45 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -18,11 +18,36 @@ import (
"sync/atomic"
)
+// SocketOptionsHandler holds methods that help define endpoint specific
+// behavior for socket options. These must be implemented by endpoints to get
+// notified when socket level options are set.
+type SocketOptionsHandler interface {
+ // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
+ OnReuseAddressSet(v bool)
+
+ // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint.
+ OnReusePortSet(v bool)
+}
+
+// DefaultSocketOptionsHandler is an embeddable type that implements no-op
+// implementations for SocketOptionsHandler methods.
+type DefaultSocketOptionsHandler struct{}
+
+var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil)
+
+// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet.
+func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {}
+
+// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet.
+func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
+
// SocketOptions contains all the variables which store values for SOL_SOCKET
// level options.
//
// +stateify savable
type SocketOptions struct {
+ handler SocketOptionsHandler
+
// These fields are accessed and modified using atomic operations.
// broadcastEnabled determines whether datagram sockets are allowed to send
@@ -36,6 +61,20 @@ type SocketOptions struct {
// noChecksumEnabled determines whether UDP checksum is disabled while
// transmitting for this socket.
noChecksumEnabled uint32
+
+ // reuseAddressEnabled determines whether Bind() should allow reuse of local
+ // address.
+ reuseAddressEnabled uint32
+
+ // reusePortEnabled determines whether to permit multiple sockets to be bound
+ // to an identical socket address.
+ reusePortEnabled uint32
+}
+
+// InitHandler initializes the handler. This must be called before using the
+// socket options utility.
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+ so.handler = handler
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -75,3 +114,25 @@ func (so *SocketOptions) GetNoChecksum() bool {
func (so *SocketOptions) SetNoChecksum(v bool) {
storeAtomicBool(&so.noChecksumEnabled, v)
}
+
+// GetReuseAddress gets value for SO_REUSEADDR option.
+func (so *SocketOptions) GetReuseAddress() bool {
+ return atomic.LoadUint32(&so.reuseAddressEnabled) != 0
+}
+
+// SetReuseAddress sets value for SO_REUSEADDR option.
+func (so *SocketOptions) SetReuseAddress(v bool) {
+ storeAtomicBool(&so.reuseAddressEnabled, v)
+ so.handler.OnReuseAddressSet(v)
+}
+
+// GetReusePort gets value for SO_REUSEPORT option.
+func (so *SocketOptions) GetReusePort() bool {
+ return atomic.LoadUint32(&so.reusePortEnabled) != 0
+}
+
+// SetReusePort sets value for SO_REUSEPORT option.
+func (so *SocketOptions) SetReusePort(v bool) {
+ storeAtomicBool(&so.reusePortEnabled, v)
+ so.handler.OnReusePortSet(v)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..2cdb5ca79 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5b9043d85..fbac66993 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,6 +38,7 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
@@ -45,7 +46,7 @@ type fakeTransportEndpoint struct {
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
// ops is used to set and get socket options.
ops tcpip.SocketOptions
@@ -65,7 +66,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -189,7 +192,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
- a := &f.acceptQueue[0]
+ a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
return a, nil, nil
}
@@ -206,7 +209,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -232,7 +235,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -240,7 +243,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 09361360f..7ae36fde7 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -738,14 +738,6 @@ const (
// interface index and address.
ReceiveIPPacketInfoOption
- // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether Bind() should allow reuse of local address.
- ReuseAddressOption
-
- // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
- // multiple sockets to be bound to an identical socket address.
- ReusePortOption
-
// V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
// whether an IPv6 socket is to be restricted to sending and receiving
// IPv6 packets only.
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 9d30329f5..8be791a00 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -510,10 +510,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
- }
-
+ ep.SocketOptions().SetReuseAddress(true)
ep.SocketOptions().SetBroadcast(true)
bindAddr := tcpip.FullAddress{Port: localPort}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index fe6514bcd..39560a9fa 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -49,6 +49,7 @@ const (
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
@@ -85,7 +86,7 @@ type endpoint struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return &endpoint{
+ ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -96,7 +97,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
- }, nil
+ }
+ ep.ops.InitHandler(ep)
+ return ep, nil
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 3bff3755a..35d1be792 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -60,6 +60,8 @@ type packet struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
@@ -107,6 +109,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
+ ep.ops.InitHandler(ep)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 0a1e1fbb3..e64392f7b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -58,6 +58,8 @@ type rawPacket struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
@@ -116,6 +118,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
associated: associated,
hdrIncluded: !associated,
}
+ e.ops.InitHandler(e)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 36b915510..f893324c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -362,6 +362,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// +stateify savable
type endpoint struct {
EndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
// a given tcp processor goroutine.
@@ -884,6 +885,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
+ e.ops.InitHandler(e)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -1627,6 +1629,20 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
return false, false
}
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+}
+
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+}
+
// SetSockOptBool sets a socket option.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
@@ -1666,16 +1682,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
atomic.StoreUint32(&e.slowAck, o)
- case tcpip.ReuseAddressOption:
- e.LockUser()
- e.portFlags.TupleOnly = v
- e.UnlockUser()
-
- case tcpip.ReusePortOption:
- e.LockUser()
- e.portFlags.LoadBalanced = v
- e.UnlockUser()
-
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -1995,20 +2001,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
v := atomic.LoadUint32(&e.slowAck) == 0
return v, nil
- case tcpip.ReuseAddressOption:
- e.LockUser()
- v := e.portFlags.TupleOnly
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.LockUser()
- v := e.portFlags.LoadBalanced
- e.UnlockUser()
-
- return v, nil
-
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index c366a4cbc..9fa3aa740 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4191,9 +4191,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4203,9 +4201,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4216,9 +4212,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4231,9 +4225,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4244,9 +4236,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4259,9 +4249,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5aa16bf35..e57833644 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -77,6 +77,7 @@ func (s EndpointState) String() string {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
@@ -194,6 +195,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
state: StateInitial,
uniqueID: s.UniqueID(),
}
+ e.ops.InitHandler(e)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -574,6 +576,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
+}
+
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
+}
+
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
@@ -602,16 +618,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.receiveIPPacketInfo = v
e.mu.Unlock()
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.portFlags.MostRecent = v
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.portFlags.LoadBalanced = v
- e.mu.Unlock()
-
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -875,20 +881,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
- case tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.portFlags.MostRecent
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.portFlags.LoadBalanced
- e.mu.RUnlock()
-
- return v, nil
-
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index d17192c36..81e00b1dc 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -819,7 +819,8 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) {
}
TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) {
- int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK};
+ int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK, SO_REUSEADDR,
+ SO_REUSEPORT};
for (int sock_opt : sock_opts) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int enable = -1;