summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2021-02-09 21:52:50 -0800
committergVisor bot <gvisor-bot@google.com>2021-02-09 21:55:16 -0800
commit298c129cc151e197db35a927f9676cc40ec80d5c (patch)
treeab9ef4a5992e53a5020522018b1ea48b8a86bcbf /pkg
parent2de36e44ed753c4cef2f9d71499fad6d87cb8b86 (diff)
Add support for setting SO_SNDBUF for unix domain sockets.
The limits for snd/rcv buffers for unix domain socket is controlled by the following sysctls on linux - net.core.rmem_default - net.core.rmem_max - net.core.wmem_default - net.core.wmem_max Today in gVisor we do not expose these sysctls but we do support setting the equivalent in netstack via stack.Options() method. But AF_UNIX sockets in gVisor can be used without netstack, with hostinet or even without any networking stack at all. Which means ideally these sysctls need to live as globals in gVisor. But rather than make this a big change for now we hardcode the limits in the AF_UNIX implementation itself (which in itself is better than where we were before) where it SO_SNDBUF was hardcoded to 16KiB. Further we bump the initial limit to a default value of 208 KiB to match linux from the paltry 16 KiB we use today. Updates #5132 PiperOrigin-RevId: 356665498
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fs/host/socket.go16
-rw-r--r--pkg/sentry/fsimpl/host/socket.go18
-rw-r--r--pkg/sentry/socket/netstack/netstack.go12
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD2
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go30
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go5
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go13
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless_state.go20
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go77
-rw-r--r--pkg/tcpip/socketops.go73
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go7
14 files changed, 186 insertions, 103 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 07b4fb70f..2b58fc52c 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -16,6 +16,7 @@ package host
import (
"fmt"
+ "sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -206,7 +207,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// only as much of the message as fits in the send buffer.
truncate := c.stype == linux.SOCK_STREAM
- n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate)
+ n, totalLen, err := fdWriteVec(c.file.FD(), data, c.SendMaxQueueSize(), truncate)
if n < totalLen && err == nil {
// The host only returns a short write if it would otherwise
// block (and only for stream sockets).
@@ -282,7 +283,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool,
// N.B. Unix sockets don't have a receive buffer, the send buffer
// serves both purposes.
- rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
+ rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.RecvMaxQueueSize())
if rl > 0 && err != nil {
// We got some data, so all we need to do on error is return
// the data that we got. Short reads are fine, no need to
@@ -363,14 +364,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
// N.B. Unix sockets don't use the receive buffer. We'll claim it is
// the same size as the send buffer.
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
@@ -381,4 +382,11 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) {
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
+func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
+ // sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 72aa535f8..6763f5b0c 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -16,6 +16,7 @@ package host
import (
"fmt"
+ "sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -111,7 +112,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
}
c.stype = linux.SockType(stype)
- c.sndbuf = int64(sndbuf)
+ atomic.StoreInt64(&c.sndbuf, int64(sndbuf))
return nil
}
@@ -150,7 +151,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// only as much of the message as fits in the send buffer.
truncate := c.stype == linux.SOCK_STREAM
- n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+ n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate)
if n < totalLen && err == nil {
// The host only returns a short write if it would otherwise
// block (and only for stream sockets).
@@ -226,7 +227,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool,
// N.B. Unix sockets don't have a receive buffer, the send buffer
// serves both purposes.
- rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+ rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.RecvMaxQueueSize())
if rl > 0 && err != nil {
// We got some data, so all we need to do on error is return
// the data that we got. Short reads are fine, no need to
@@ -300,14 +301,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
// N.B. Unix sockets don't use the receive buffer. We'll claim it is
// the same size as the send buffer.
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
func (c *ConnectedEndpoint) destroyLocked() {
@@ -327,6 +328,13 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) {
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
+func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
+ // sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
// passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the
// following differences:
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 69693f263..cee8120ab 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -855,10 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- size, err := ep.SocketOptions().GetSendBufferSize()
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ size := ep.SocketOptions().GetSendBufferSize()
if size > math.MaxInt32 {
size = math.MaxInt32
@@ -1647,13 +1644,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- family, _, _ := s.Type()
- // TODO(gvisor.dev/issue/5132): We currently do not support
- // setting this option for unix sockets.
- if family == linux.AF_UNIX {
- return nil
- }
-
v := usermem.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetSendBufferSize(int64(v), true)
return nil
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cce0acc33..acf2ab8e7 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -51,6 +51,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 3ebbd28b0..0d11bb251 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -32,6 +32,7 @@ go_library(
"connectioned.go",
"connectioned_state.go",
"connectionless.go",
+ "connectionless_state.go",
"queue.go",
"queue_refs.go",
"transport_message_list.go",
@@ -45,6 +46,7 @@ go_library(
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
+ "//pkg/sentry/inet",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index fc5b823b0..809c95429 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -128,7 +128,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, nil, nil)
+
+ ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
return ep
}
@@ -137,9 +139,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
a := newConnectioned(ctx, stype, uid)
b := newConnectioned(ctx, stype, uid)
- q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize}
q1.InitRefs()
- q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize}
q2.InitRefs()
if stype == linux.SOCK_STREAM {
@@ -173,7 +175,8 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, nil, nil)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */)
return ep
}
@@ -296,16 +299,18 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
- ne.ops.InitHandler(ne, nil, nil)
+ ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits)
+ ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
readQueue.InitRefs()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
- writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ // Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF.
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()}
writeQueue.InitRefs()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
@@ -357,6 +362,9 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
e.receiver = r
e.connected = ce
+ // Make sure the newly created connected endpoint's write queue is updated
+ // to reflect this endpoint's send buffer size.
+ e.connected.SetSendBufferSize(e.ops.GetSendBufferSize())
}
return server.BidirectionalConnect(ctx, e, returnConnect)
@@ -495,3 +503,11 @@ func (e *connectionedEndpoint) State() uint32 {
}
return linux.SS_UNCONNECTED
}
+
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
+ if e.Connected() {
+ return e.baseEndpoint.connected.SetSendBufferSize(v)
+ }
+ return v
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
index 7e02a5db8..590b0bd01 100644
--- a/pkg/sentry/socket/unix/transport/connectioned_state.go
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -51,3 +51,8 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd
}
}
}
+
+// afterLoad is invoked by stateify.
+func (e *connectionedEndpoint) afterLoad() {
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 20fa8b874..0be78480c 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -41,10 +41,11 @@ var (
// NewConnectionless creates a new unbound dgram endpoint.
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
- q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
- ep.ops.InitHandler(ep, nil, nil)
+ ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
return ep
}
@@ -217,3 +218,11 @@ func (e *connectionlessEndpoint) State() uint32 {
return linux.SS_DISCONNECTING
}
}
+
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
+ if e.Connected() {
+ return e.baseEndpoint.connected.SetSendBufferSize(v)
+ }
+ return v
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go
new file mode 100644
index 000000000..2ef337ec8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless_state.go
@@ -0,0 +1,20 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// afterLoad is invoked by stateify.
+func (e *connectionlessEndpoint) afterLoad() {
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 342def28f..698a9a82c 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -237,9 +237,18 @@ func (q *queue) QueuedSize() int64 {
// MaxQueueSize returns the maximum number of bytes storable in the queue.
func (q *queue) MaxQueueSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
return q.limit
}
+// SetMaxQueueSize sets the maximum number of bytes storable in the queue.
+func (q *queue) SetMaxQueueSize(v int64) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.limit = v
+}
+
// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
// with unread data. So if read on this queue shall return ECONNRESET error.
func (q *queue) CloseUnread() {
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 70227bbd2..ceada54a8 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -26,8 +26,16 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// initialLimit is the starting limit for the socket buffers.
-const initialLimit = 16 * 1024
+const (
+ // The minimum size of the send/receive buffers.
+ minimumBufferSize = 4 << 10 // 4 KiB (match default in linux)
+
+ // The default size of the send/receive buffers.
+ defaultBufferSize = 208 << 10 // 208 KiB (default in linux for net.core.wmem_default)
+
+ // The maximum permitted size for the send/receive buffers.
+ maxBufferSize = 4 << 20 // 4 MiB 4 MiB (default in linux for net.core.wmem_max)
+)
// A RightsControlMessage is a control message containing FDs.
//
@@ -627,6 +635,10 @@ type ConnectedEndpoint interface {
// CloseUnread sets the fact that this end is closed with unread data to
// the peer socket.
CloseUnread()
+
+ // SetSendBufferSize is called when the endpoint's send buffer size is
+ // changed.
+ SetSendBufferSize(v int64) (newSz int64)
}
// +stateify savable
@@ -722,6 +734,14 @@ func (e *connectedEndpoint) CloseUnread() {
e.writeQueue.CloseUnread()
}
+// SetSendBufferSize implements ConnectedEndpoint.SetSendBufferSize.
+// SetSendBufferSize sets the send buffer size for the write queue to the
+// specified value.
+func (e *connectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ e.writeQueue.SetMaxQueueSize(v)
+ return v
+}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
@@ -849,27 +869,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return nil
}
-// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket.
-func (e *baseEndpoint) IsUnixSocket() bool {
- return true
-}
-
-// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize.
-func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) {
- e.Lock()
- defer e.Unlock()
-
- if !e.Connected() {
- return -1, &tcpip.ErrNotConnected{}
- }
-
- v := e.connected.SendMaxQueueSize()
- if v < 0 {
- return -1, &tcpip.ErrQueueSizeNotSupported{}
- }
- return v, nil
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -987,3 +986,35 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
func (*baseEndpoint) Release(context.Context) {
// Binding a baseEndpoint doesn't take a reference.
}
+
+// stackHandler is just a stub implementation of tcpip.StackHandler to provide
+// when initializing socketoptions.
+type stackHandler struct {
+}
+
+// Option implements tcpip.StackHandler.
+func (h *stackHandler) Option(option interface{}) tcpip.Error {
+ panic("unimplemented")
+}
+
+// TransportProtocolOption implements tcpip.StackHandler.
+func (h *stackHandler) TransportProtocolOption(proto tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error {
+ panic("unimplemented")
+}
+
+// getSendBufferLimits implements tcpip.GetSendBufferLimits.
+//
+// AF_UNIX sockets buffer sizes are not tied to the networking stack/namespace
+// in linux but are bound by net.core.(wmem|rmem)_(max|default).
+//
+// In gVisor net.core sysctls today are not exposed or if exposed are currently
+// tied to the networking stack in use. This makes it complicated for AF_UNIX
+// when we are in a new namespace w/ no networking stack. As a result for now we
+// define default/max values here in the unix socket implementation itself.
+func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption {
+ return tcpip.SendBufferSizeOption{
+ Min: minimumBufferSize,
+ Default: defaultBufferSize,
+ Max: maxBufferSize,
+ }
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 1e00144a5..dc37e61a4 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -54,11 +54,10 @@ type SocketOptionsHandler interface {
// HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
HasNIC(v int32) bool
- // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE.
- GetSendBufferSize() (int64, Error)
-
- // IsUnixSocket is invoked to check if the socket is of unix domain.
- IsUnixSocket() bool
+ // OnSetSendBufferSize is invoked when the send buffer size for an endpoint is
+ // changed. The handler is invoked with the new value for the socket send
+ // buffer size. It also returns the newly set value.
+ OnSetSendBufferSize(v int64) (newSz int64)
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -95,14 +94,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
return false
}
-// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize.
-func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) {
- return 0, nil
-}
-
-// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket.
-func (*DefaultSocketOptionsHandler) IsUnixSocket() bool {
- return false
+// OnSetSendBufferSize implements SocketOptionsHandler.OnSetSendBufferSize.
+func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
+ return v
}
// StackHandler holds methods to access the stack options. These must be
@@ -600,42 +594,41 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
}
// GetSendBufferSize gets value for SO_SNDBUF option.
-func (so *SocketOptions) GetSendBufferSize() (int64, Error) {
- if so.handler.IsUnixSocket() {
- return so.handler.GetSendBufferSize()
- }
- return atomic.LoadInt64(&so.sendBufferSize), nil
+func (so *SocketOptions) GetSendBufferSize() int64 {
+ return atomic.LoadInt64(&so.sendBufferSize)
}
// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
// stack handler should be invoked to set the send buffer size.
func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
- if so.handler.IsUnixSocket() {
+ v := sendBufferSize
+
+ if !notify {
+ atomic.StoreInt64(&so.sendBufferSize, v)
return
}
- v := sendBufferSize
- if notify {
- // TODO(b/176170271): Notify waiters after size has grown.
- // Make sure the send buffer size is within the min and max
- // allowed.
- ss := so.getSendBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- // Multiply it by factor of 2.
- if v > max {
- v = max
- }
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ ss := so.getSendBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Validate the send buffer size with min and max values.
+ // Multiply it by factor of 2.
+ if v > max {
+ v = max
+ }
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
}
+ } else {
+ v = math.MaxInt32
}
- atomic.StoreInt64(&so.sendBufferSize, v)
+
+ // Notify endpoint about change in buffer size.
+ newSz := so.handler.OnSetSendBufferSize(v)
+ atomic.StoreInt64(&so.sendBufferSize, newSz)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4e5a6089f..8c5be0586 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1698,11 +1698,7 @@ func (e *endpoint) OnCorkOptionSet(v bool) {
}
func (e *endpoint) getSendBufferSize() int {
- sndBufSize, err := e.ops.GetSendBufferSize()
- if err != nil {
- panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err))
- }
- return int(sndBufSize)
+ return int(e.ops.GetSendBufferSize())
}
// SetSockOptInt sets a socket option.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cd3c4a027..0128c1f7e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4393,12 +4393,7 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.SocketOptions().GetSendBufferSize()
- if err != nil {
- t.Fatalf("GetSendBufferSize failed: %s", err)
- }
-
- if int(s) != v {
+ if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v {
t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}