summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorNayana Bidari <nybidari@google.com>2020-11-12 22:55:09 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-12 22:57:00 -0800
commit5bb64ce1b8c42fcd96e44a5be05e17f34a83f840 (patch)
tree6dbfa200de6de2d6d53ddd8f3e7d35518ef22203 /pkg/tcpip/transport
parentbf392dcc7d7b1f256acfe8acd2758a77db3fc8a2 (diff)
Refactor SOL_SOCKET options
Store all the socket level options in a struct and call {Get/Set}SockOpt on this struct. This will avoid implementing socket level options on all endpoints. This CL contains implementing one socket level option for tcp and udp endpoints. PiperOrigin-RevId: 342203981
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go7
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go20
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go22
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go12
6 files changed, 39 insertions, 36 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 763cd8f84..440cb0352 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -79,6 +79,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -853,3 +856,7 @@ func (*endpoint) Wait() {}
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 31831a6d8..3bff3755a 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -89,6 +89,9 @@ type endpoint struct {
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a new packet endpoint.
@@ -549,3 +552,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
}
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+
+func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &ep.ops
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7b6a87ba9..4ae1f92ab 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -89,6 +89,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -756,3 +759,7 @@ func (*endpoint) Wait() {}
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23b9de8c5..194d3a8a4 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -440,9 +440,6 @@ type endpoint struct {
ttl uint8
v6only bool
isConnectNotified bool
- // TCP should never broadcast but Linux nevertheless supports enabling/
- // disabling SO_BROADCAST, albeit as a NOOP.
- broadcast bool
// portFlags stores the current values of port related flags.
portFlags ports.Flags
@@ -685,6 +682,9 @@ type endpoint struct {
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -1599,11 +1599,6 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
- case tcpip.BroadcastOption:
- e.LockUser()
- e.broadcast = v
- e.UnlockUser()
-
case tcpip.CorkOption:
e.LockUser()
if !v {
@@ -1950,11 +1945,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.BroadcastOption:
- e.LockUser()
- v := e.broadcast
- e.UnlockUser()
- return v, nil
case tcpip.CorkOption:
return atomic.LoadUint32(&e.cork) != 0, nil
@@ -3130,3 +3120,7 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
+
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9bcb918bb..57976d4e3 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -108,7 +108,6 @@ type endpoint struct {
multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- broadcast bool
noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
@@ -157,6 +156,9 @@ type endpoint struct {
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// +stateify savable
@@ -508,7 +510,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
resolve = route.Resolve
}
- if !e.broadcast && route.IsOutboundBroadcast() {
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
return 0, nil, tcpip.ErrBroadcastDisabled
}
@@ -553,11 +555,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v
- e.mu.Unlock()
-
case tcpip.MulticastLoopOption:
e.mu.Lock()
e.multicastLoop = v
@@ -614,7 +611,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.v6only = v
}
-
return nil
}
@@ -830,12 +826,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
- e.mu.RUnlock()
- return v, nil
-
case tcpip.KeepaliveEnabledOption:
return false, nil
@@ -1525,3 +1515,7 @@ func isBroadcastOrMulticast(a tcpip.Address) bool {
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index df62177cd..764ad0857 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -364,9 +364,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.t.Fatalf("SetSockOptBool failed: %s", err)
}
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetBroadcast(true)
}
}
@@ -2397,17 +2395,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(true)
if n, _, err := ep.Write(data, opts); err != nil {
t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(false)
if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)