diff options
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 16 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/host/socket.go | 18 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/BUILD | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned_state.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 13 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless_state.go | 20 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/queue.go | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 77 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 73 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 7 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 3 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 2 | ||||
-rw-r--r-- | test/syscalls/linux/socket_unix_dgram.cc | 35 | ||||
-rw-r--r-- | test/syscalls/linux/socket_unix_seqpacket.cc | 35 | ||||
-rw-r--r-- | test/syscalls/linux/socket_unix_stream.cc | 80 |
19 files changed, 340 insertions, 104 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 07b4fb70f..2b58fc52c 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,6 +16,7 @@ package host import ( "fmt" + "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -206,7 +207,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // only as much of the message as fits in the send buffer. truncate := c.stype == linux.SOCK_STREAM - n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) + n, totalLen, err := fdWriteVec(c.file.FD(), data, c.SendMaxQueueSize(), truncate) if n < totalLen && err == nil { // The host only returns a short write if it would otherwise // block (and only for stream sockets). @@ -282,7 +283,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, // N.B. Unix sockets don't have a receive buffer, the send buffer // serves both purposes. - rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf) + rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.RecvMaxQueueSize()) if rl > 0 && err != nil { // We got some data, so all we need to do on error is return // the data that we got. Short reads are fine, no need to @@ -363,14 +364,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { // N.B. Unix sockets don't use the receive buffer. We'll claim it is // the same size as the send buffer. - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release. @@ -381,4 +382,11 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) { // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. func (c *ConnectedEndpoint) CloseUnread() {} +// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. +func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_SNDBUF for host backed unix domain + // sockets. + return atomic.LoadInt64(&c.sndbuf) +} + // LINT.ThenChange(../../fsimpl/host/socket.go) diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 72aa535f8..6763f5b0c 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -16,6 +16,7 @@ package host import ( "fmt" + "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -111,7 +112,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { } c.stype = linux.SockType(stype) - c.sndbuf = int64(sndbuf) + atomic.StoreInt64(&c.sndbuf, int64(sndbuf)) return nil } @@ -150,7 +151,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // only as much of the message as fits in the send buffer. truncate := c.stype == linux.SOCK_STREAM - n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate) + n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate) if n < totalLen && err == nil { // The host only returns a short write if it would otherwise // block (and only for stream sockets). @@ -226,7 +227,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, // N.B. Unix sockets don't have a receive buffer, the send buffer // serves both purposes. - rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf) + rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.RecvMaxQueueSize()) if rl > 0 && err != nil { // We got some data, so all we need to do on error is return // the data that we got. Short reads are fine, no need to @@ -300,14 +301,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { // N.B. Unix sockets don't use the receive buffer. We'll claim it is // the same size as the send buffer. - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } func (c *ConnectedEndpoint) destroyLocked() { @@ -327,6 +328,13 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) { // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. func (c *ConnectedEndpoint) CloseUnread() {} +// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. +func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_SNDBUF for host backed unix domain + // sockets. + return atomic.LoadInt64(&c.sndbuf) +} + // SCMConnectedEndpoint represents an endpoint backed by a host fd that was // passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the // following differences: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 69693f263..cee8120ab 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -855,10 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.SocketOptions().GetSendBufferSize() - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } + size := ep.SocketOptions().GetSendBufferSize() if size > math.MaxInt32 { size = math.MaxInt32 @@ -1647,13 +1644,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - family, _, _ := s.Type() - // TODO(gvisor.dev/issue/5132): We currently do not support - // setting this option for unix sockets. - if family == linux.AF_UNIX { - return nil - } - v := usermem.ByteOrder.Uint32(optVal) ep.SocketOptions().SetSendBufferSize(int64(v), true) return nil diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cce0acc33..acf2ab8e7 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", + "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 3ebbd28b0..0d11bb251 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "connectioned.go", "connectioned_state.go", "connectionless.go", + "connectionless_state.go", "queue.go", "queue_refs.go", "transport_message_list.go", @@ -45,6 +46,7 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", + "//pkg/sentry/inet", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index fc5b823b0..809c95429 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -128,7 +128,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, nil, nil) + + ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) return ep } @@ -137,9 +139,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E a := newConnectioned(ctx, stype, uid) b := newConnectioned(ctx, stype, uid) - q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} + q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize} q1.InitRefs() - q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} + q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize} q2.InitRefs() if stype == linux.SOCK_STREAM { @@ -173,7 +175,8 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, nil, nil) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) return ep } @@ -296,16 +299,18 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne, nil, nil) + ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits) + ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} + readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } - writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} + // Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF. + writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()} writeQueue.InitRefs() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} @@ -357,6 +362,9 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint returnConnect := func(r Receiver, ce ConnectedEndpoint) { e.receiver = r e.connected = ce + // Make sure the newly created connected endpoint's write queue is updated + // to reflect this endpoint's send buffer size. + e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()) } return server.BidirectionalConnect(ctx, e, returnConnect) @@ -495,3 +503,11 @@ func (e *connectionedEndpoint) State() uint32 { } return linux.SS_UNCONNECTED } + +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { + if e.Connected() { + return e.baseEndpoint.connected.SetSendBufferSize(v) + } + return v +} diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go index 7e02a5db8..590b0bd01 100644 --- a/pkg/sentry/socket/unix/transport/connectioned_state.go +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -51,3 +51,8 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd } } } + +// afterLoad is invoked by stateify. +func (e *connectionedEndpoint) afterLoad() { + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 20fa8b874..0be78480c 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -41,10 +41,11 @@ var ( // NewConnectionless creates a new unbound dgram endpoint. func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} - q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} + q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} - ep.ops.InitHandler(ep, nil, nil) + ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) return ep } @@ -217,3 +218,11 @@ func (e *connectionlessEndpoint) State() uint32 { return linux.SS_DISCONNECTING } } + +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { + if e.Connected() { + return e.baseEndpoint.connected.SetSendBufferSize(v) + } + return v +} diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go new file mode 100644 index 000000000..2ef337ec8 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/connectionless_state.go @@ -0,0 +1,20 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +// afterLoad is invoked by stateify. +func (e *connectionlessEndpoint) afterLoad() { + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) +} diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 342def28f..698a9a82c 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -237,9 +237,18 @@ func (q *queue) QueuedSize() int64 { // MaxQueueSize returns the maximum number of bytes storable in the queue. func (q *queue) MaxQueueSize() int64 { + q.mu.Lock() + defer q.mu.Unlock() return q.limit } +// SetMaxQueueSize sets the maximum number of bytes storable in the queue. +func (q *queue) SetMaxQueueSize(v int64) { + q.mu.Lock() + defer q.mu.Unlock() + q.limit = v +} + // CloseUnread sets flag to indicate that the peer is closed (not shutdown) // with unread data. So if read on this queue shall return ECONNRESET error. func (q *queue) CloseUnread() { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 70227bbd2..ceada54a8 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -26,8 +26,16 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// initialLimit is the starting limit for the socket buffers. -const initialLimit = 16 * 1024 +const ( + // The minimum size of the send/receive buffers. + minimumBufferSize = 4 << 10 // 4 KiB (match default in linux) + + // The default size of the send/receive buffers. + defaultBufferSize = 208 << 10 // 208 KiB (default in linux for net.core.wmem_default) + + // The maximum permitted size for the send/receive buffers. + maxBufferSize = 4 << 20 // 4 MiB 4 MiB (default in linux for net.core.wmem_max) +) // A RightsControlMessage is a control message containing FDs. // @@ -627,6 +635,10 @@ type ConnectedEndpoint interface { // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. CloseUnread() + + // SetSendBufferSize is called when the endpoint's send buffer size is + // changed. + SetSendBufferSize(v int64) (newSz int64) } // +stateify savable @@ -722,6 +734,14 @@ func (e *connectedEndpoint) CloseUnread() { e.writeQueue.CloseUnread() } +// SetSendBufferSize implements ConnectedEndpoint.SetSendBufferSize. +// SetSendBufferSize sets the send buffer size for the write queue to the +// specified value. +func (e *connectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + e.writeQueue.SetMaxQueueSize(v) + return v +} + // baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless // unix domain socket Endpoint implementations. // @@ -849,27 +869,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return nil } -// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket. -func (e *baseEndpoint) IsUnixSocket() bool { - return true -} - -// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. -func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) { - e.Lock() - defer e.Unlock() - - if !e.Connected() { - return -1, &tcpip.ErrNotConnected{} - } - - v := e.connected.SendMaxQueueSize() - if v < 0 { - return -1, &tcpip.ErrQueueSizeNotSupported{} - } - return v, nil -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -987,3 +986,35 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } + +// stackHandler is just a stub implementation of tcpip.StackHandler to provide +// when initializing socketoptions. +type stackHandler struct { +} + +// Option implements tcpip.StackHandler. +func (h *stackHandler) Option(option interface{}) tcpip.Error { + panic("unimplemented") +} + +// TransportProtocolOption implements tcpip.StackHandler. +func (h *stackHandler) TransportProtocolOption(proto tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { + panic("unimplemented") +} + +// getSendBufferLimits implements tcpip.GetSendBufferLimits. +// +// AF_UNIX sockets buffer sizes are not tied to the networking stack/namespace +// in linux but are bound by net.core.(wmem|rmem)_(max|default). +// +// In gVisor net.core sysctls today are not exposed or if exposed are currently +// tied to the networking stack in use. This makes it complicated for AF_UNIX +// when we are in a new namespace w/ no networking stack. As a result for now we +// define default/max values here in the unix socket implementation itself. +func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption { + return tcpip.SendBufferSizeOption{ + Min: minimumBufferSize, + Default: defaultBufferSize, + Max: maxBufferSize, + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 1e00144a5..dc37e61a4 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -54,11 +54,10 @@ type SocketOptionsHandler interface { // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool - // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. - GetSendBufferSize() (int64, Error) - - // IsUnixSocket is invoked to check if the socket is of unix domain. - IsUnixSocket() bool + // OnSetSendBufferSize is invoked when the send buffer size for an endpoint is + // changed. The handler is invoked with the new value for the socket send + // buffer size. It also returns the newly set value. + OnSetSendBufferSize(v int64) (newSz int64) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -95,14 +94,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } -// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. -func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) { - return 0, nil -} - -// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. -func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { - return false +// OnSetSendBufferSize implements SocketOptionsHandler.OnSetSendBufferSize. +func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { + return v } // StackHandler holds methods to access the stack options. These must be @@ -600,42 +594,41 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { } // GetSendBufferSize gets value for SO_SNDBUF option. -func (so *SocketOptions) GetSendBufferSize() (int64, Error) { - if so.handler.IsUnixSocket() { - return so.handler.GetSendBufferSize() - } - return atomic.LoadInt64(&so.sendBufferSize), nil +func (so *SocketOptions) GetSendBufferSize() int64 { + return atomic.LoadInt64(&so.sendBufferSize) } // SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the // stack handler should be invoked to set the send buffer size. func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { - if so.handler.IsUnixSocket() { + v := sendBufferSize + + if !notify { + atomic.StoreInt64(&so.sendBufferSize, v) return } - v := sendBufferSize - if notify { - // TODO(b/176170271): Notify waiters after size has grown. - // Make sure the send buffer size is within the min and max - // allowed. - ss := so.getSendBufferLimits(so.stackHandler) - min := int64(ss.Min) - max := int64(ss.Max) - // Validate the send buffer size with min and max values. - // Multiply it by factor of 2. - if v > max { - v = max - } + // Make sure the send buffer size is within the min and max + // allowed. + ss := so.getSendBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } - if v < math.MaxInt32/PacketOverheadFactor { - v *= PacketOverheadFactor - if v < min { - v = min - } - } else { - v = math.MaxInt32 + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min } + } else { + v = math.MaxInt32 } - atomic.StoreInt64(&so.sendBufferSize, v) + + // Notify endpoint about change in buffer size. + newSz := so.handler.OnSetSendBufferSize(v) + atomic.StoreInt64(&so.sendBufferSize, newSz) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4e5a6089f..8c5be0586 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1698,11 +1698,7 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } func (e *endpoint) getSendBufferSize() int { - sndBufSize, err := e.ops.GetSendBufferSize() - if err != nil { - panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) - } - return int(sndBufSize) + return int(e.ops.GetSendBufferSize()) } // SetSockOptInt sets a socket option. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cd3c4a027..0128c1f7e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4393,12 +4393,7 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.SocketOptions().GetSendBufferSize() - if err != nil { - t.Fatalf("GetSendBufferSize failed: %s", err) - } - - if int(s) != v { + if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { t.Fatalf("got send buffer size = %d, want = %d", s, v) } } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 499027e7c..42fc363a2 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2346,6 +2346,7 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_util", ], @@ -2360,6 +2361,7 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_util", ], @@ -2678,6 +2680,7 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index de0b8bb11..f70047a09 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -379,7 +379,7 @@ TEST_P(AllSocketPairTest, RcvBufSucceeds) { EXPECT_GT(size, 0); } -TEST_P(AllSocketPairTest, SndBufSucceeds) { +TEST_P(AllSocketPairTest, GetSndBufSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int size = 0; socklen_t size_size = sizeof(size); diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index af0df4fb4..5b0844493 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -18,6 +18,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -39,6 +41,39 @@ TEST_P(DgramUnixSocketPairTest, WriteOneSideClosed) { SyscallFailsWithErrno(ECONNREFUSED)); } +TEST_P(DgramUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 6d03df4d9..eb373373d 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -18,6 +18,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -61,6 +63,39 @@ TEST_P(SeqpacketUnixSocketPairTest, Sendto) { SyscallSucceedsWithValue(3)); } +TEST_P(SeqpacketUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index ad9c4bf37..3ff810914 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -17,6 +17,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -134,6 +136,84 @@ TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) { EXPECT_EQ(got, 0); } +TEST_P(StreamUnixSocketPairTest, SetSocketSendBuf) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + auto s = sockets->first_fd(); + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = INT_MAX; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_SNDBUF. + quarter_sz *= 2; + ASSERT_EQ(quarter_sz, val); +} + +TEST_P(StreamUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( |