diff options
author | Nayana Bidari <nybidari@google.com> | 2021-01-26 08:23:49 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-26 08:25:34 -0800 |
commit | daf0d3f6ca3aad6f3f9ab4d762546c6dee78fa57 (patch) | |
tree | 84ebfd054c94040e516bb48703e507660b742841 | |
parent | 3946075403a93907138f13e61bdba075aeabfecf (diff) |
Move SO_SNDBUF to socketops.
This CL moves {S,G}etsockopt of SO_SNDBUF from all endpoints to socketops. For
unix sockets, we do not support setting of this option.
PiperOrigin-RevId: 353871484
24 files changed, 246 insertions, 235 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3f34668cf..b4d0651b8 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -846,7 +846,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + size, err := ep.SocketOptions().GetSendBufferSize() if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1615,8 +1615,21 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } + family, skType, skProto := s.Type() + // TODO(gvisor.dev/issue/5132): We currently do not support + // setting this option for unix sockets. + if family == linux.AF_UNIX { + return nil + } + + getBufferLimits := tcpip.GetStackSendBufferLimits + if isTCPSocket(skType, skProto) { + getBufferLimits = tcp.GetTCPSendBufferLimits + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) + ep.SocketOptions().SetSendBufferSize(int64(v), true, getBufferLimits) + return nil case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9f7aca305..dbb7f7c31 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -128,7 +128,7 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil) return ep } @@ -173,7 +173,7 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil) return ep } @@ -296,7 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne) + ne.ops.InitHandler(ne, nil) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0813ad87d..895d2322e 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,7 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 099a56281..0e3889c6d 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -842,7 +842,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: case tcpip.ReceiveBufferSizeOption: default: log.Warningf("Unsupported socket option: %d", opt) @@ -850,6 +849,27 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket. +func (e *baseEndpoint) IsUnixSocket() bool { + return true +} + +// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. +func (e *baseEndpoint) GetSendBufferSize() (int64, *tcpip.Error) { + e.Lock() + defer e.Unlock() + + if !e.Connected() { + return -1, tcpip.ErrNotConnected + } + + v := e.connected.SendMaxQueueSize() + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return v, nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -879,19 +899,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } return int(v), nil - case tcpip.SendBufferSizeOption: - e.Lock() - if !e.Connected() { - e.Unlock() - return -1, tcpip.ErrNotConnected - } - v := e.connected.SendMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported - } - return int(v), nil - case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index f3ad40fdf..4149b91da 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,11 +15,16 @@ package tcpip import ( + "math" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" ) +// PacketOverheadFactor is used to multiply the value provided by the user on a +// SetSockOpt for setting the send/receive buffer sizes sockets. +const PacketOverheadFactor = 2 + // SocketOptionsHandler holds methods that help define endpoint specific // behavior for socket level socket options. These must be implemented by // endpoints to get notified when socket level options are set. @@ -48,6 +53,12 @@ type SocketOptionsHandler interface { // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool + + // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. + GetSendBufferSize() (int64, *Error) + + // IsUnixSocket is invoked to check if the socket is of unix domain. + IsUnixSocket() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -84,6 +95,27 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } +// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. +func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, *Error) { + return 0, nil +} + +// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. +func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { + return false +} + +// StackHandler holds methods to access the stack options. These must be +// implemented by the stack. +type StackHandler interface { + // Option allows retrieving stack wide options. + Option(option interface{}) *Error + + // TransportProtocolOption allows retrieving individual protocol level + // option values. + TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) *Error +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { type SocketOptions struct { handler SocketOptionsHandler + // StackHandler is initialized at the creation time and will not change. + stackHandler StackHandler `state:"manual"` + // These fields are accessed and modified using atomic operations. // broadcastEnabled determines whether datagram sockets are allowed to @@ -170,6 +205,9 @@ type SocketOptions struct { // bindToDevice determines the device to which the socket is bound. bindToDevice int32 + // sendBufferSize determines the send buffer size for this socket. + sendBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -180,8 +218,9 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler) { so.handler = handler + so.stackHandler = stack } func storeAtomicBool(addr *uint32, v bool) { @@ -518,3 +557,41 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { atomic.StoreInt32(&so.bindToDevice, bindToDevice) return nil } + +// GetSendBufferSize gets value for SO_SNDBUF option. +func (so *SocketOptions) GetSendBufferSize() (int64, *Error) { + if so.handler.IsUnixSocket() { + return so.handler.GetSendBufferSize() + } + return atomic.LoadInt64(&so.sendBufferSize), nil +} + +// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the +// stack handler should be invoked to set the send buffer size. +func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool, getBufferLimits GetSendBufferLimits) { + v := sendBufferSize + ss := getBufferLimits(so.stackHandler) + + if notify { + // TODO(b/176170271): Notify waiters after size has grown. + // Make sure the send buffer size is within the min and max + // allowed. + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } + + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + } + atomic.StoreInt64(&so.sendBufferSize, v) +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ee05c6013..fc7b9ea56 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -444,7 +444,7 @@ type Stack struct { // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. - sendBufferSize SendBufferSizeOption + sendBufferSize tcpip.SendBufferSizeOption // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. @@ -646,7 +646,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 0b093e6c5..92e70f94e 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -14,7 +14,9 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // MinBufferSize is the smallest size of a receive or send buffer. @@ -29,14 +31,6 @@ const ( DefaultMaxBufferSize = 4 << 20 // 4 MiB ) -// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to -// get/set the default, min and max send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max receive buffer sizes. type ReceiveBufferSizeOption struct { @@ -48,7 +42,7 @@ type ReceiveBufferSizeOption struct { // SetOption allows setting stack wide options. func (s *Stack) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case SendBufferSizeOption: + case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { @@ -88,7 +82,7 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Option allows retrieving stack wide options. func (s *Stack) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *SendBufferSizeOption: + case *tcpip.SendBufferSizeOption: s.mu.RLock() *v = s.sendBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b94568a8e..0f02f1d53 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3336,21 +3336,21 @@ func TestStackSendBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - ss stack.SendBufferSizeOption + ss tcpip.SendBufferSizeOption err *tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, // Valid Configurations - {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3359,7 +3359,7 @@ func TestStackSendBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.ss); err != tc.err { t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) } - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if tc.err == nil { if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index dbf8b4db1..f1ac8b777 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -68,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} - ep.ops.InitHandler(ep) +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} + ep.ops.InitHandler(ep, s) return ep } @@ -234,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress, route: route, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, f.proto.stack) f.acceptQueue = append(f.acceptQueue, ep) } @@ -282,7 +282,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil + return newFakeTransportEndpoint(f, netProto, f.stack), nil } func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 32e4f02ca..4e2c4f906 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -826,10 +826,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the send buffer size option. - SendBufferSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the receive buffer size option. ReceiveBufferSizeOption @@ -1234,6 +1230,31 @@ type IPPacketInfo struct { DestinationAddr Address } +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + // Min is the minimum size for send buffer. + Min int + + // Default is the default size for send buffer. + Default int + + // Max is the maximum size for send buffer. + Max int +} + +// GetSendBufferLimits is used to get the send buffer size limits. +type GetSendBufferLimits func(StackHandler) SendBufferSizeOption + +// GetStackSendBufferLimits is used to get default, min and max send buffer size. +func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { + var ss SendBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index af00ed548..85b497365 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -69,8 +69,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int + mu sync.RWMutex `state:"nosave"` // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState @@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack) + ep.ops.SetSendBufferSize(32*1024, false /* notify */, tcpip.GetStackSendBufferLimits) + + // Override with stack defaults. + var ss tcpip.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */, tcpip.GetStackSendBufferLimits) + } return ep, nil } @@ -363,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSize - e.mu.Unlock() - return v, nil case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 9d263c0ec..9335cbc5a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -69,6 +69,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 6fd116a98..9ca9c9780 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,13 +79,11 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool - boundNIC tcpip.NICID + mu sync.RWMutex `state:"nosave"` + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb netProto: netProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - ep.sndBufSizeMax = ss.Default + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */, tcpip.GetStackSendBufferLimits) } var rs stack.ReceiveBufferSizeOption @@ -320,24 +317,6 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := ep.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - ep.mu.Lock() - ep.sndBufSizeMax = v - ep.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSizeMax - ep.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: ep.rcvMu.Lock() v := ep.rcvBufSizeMax diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e2fa96d17..61221be71 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -63,8 +63,8 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. ep.stack = stack.StackFromEnv + ep.ops.InitHandler(ep, ep.stack) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 2dacf5a64..f038097e4 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -76,12 +76,10 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` @@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, associated: associated, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack) e.ops.SetHeaderIncluded(!associated) + e.ops.SetSendBufferSize(32*1024, false /* notify */, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */, tcpip.GetStackSendBufferLimits) } var rs stack.ReceiveBufferSizeOption @@ -511,24 +509,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -570,12 +550,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 4a7e1c039..e11ab38d5 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -69,6 +69,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack) // If the endpoint is connected, re-connect. if e.connected { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 05c431e83..8e43fec81 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -557,7 +557,6 @@ type endpoint struct { // When the send side is closed, the protocol goroutine is notified via // sndCloseWaker, and sndClosed is set to true. sndBufMu sync.Mutex `state:"nosave"` - sndBufSize int sndBufUsed int sndClosed bool sndBufInQueue seqnum.Size @@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. @@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) + e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */, GetTCPSendBufferLimits) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - e.sndBufSize = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */, GetTCPSendBufferLimits) } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() - if e.sndClosed || e.sndBufUsed < e.sndBufSize { + sndBufSize := e.getSendBufferSize() + if e.sndClosed || e.sndBufUsed < sndBufSize { result |= waiter.EventOut } e.sndBufMu.Unlock() @@ -1499,7 +1499,8 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { return 0, tcpip.ErrClosedForSend } - avail := e.sndBufSize - e.sndBufUsed + sndBufSize := e.getSendBufferSize() + avail := sndBufSize - e.sndBufUsed if avail <= 0 { return 0, tcpip.ErrWouldBlock } @@ -1692,6 +1693,14 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } } +func (e *endpoint) getSendBufferSize() int { + sndBufSize, err := e.ops.GetSendBufferSize() + if err != nil { + panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) + } + return int(sndBufSize) +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1785,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) - } - - if v > ss.Max { - v = ss.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < ss.Min { - v = ss.Min - } - } else { - v = math.MaxInt32 - } - - e.sndBufMu.Lock() - e.sndBufSize = v - e.sndBufMu.Unlock() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1995,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - v := e.sndBufSize - e.sndBufMu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvListMu.Lock() v := e.rcvBufSize @@ -2749,13 +2727,14 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt // updateSndBufferUsage is called by the protocol goroutine when room opens up // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { + sendBufferSize := e.getSendBufferSize() e.sndBufMu.Lock() - notify := e.sndBufUsed >= e.sndBufSize>>1 + notify := e.sndBufUsed >= sendBufferSize>>1 e.sndBufUsed -= v - // We only notify when there is half the sndBufSize available after + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < e.sndBufSize>>1 + notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 e.sndBufMu.Unlock() if notify { @@ -2967,8 +2946,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() // Copy endpoint send state. + sndBufSize := e.getSendBufferSize() e.sndBufMu.Lock() - s.SndBufSize = e.sndBufSize + s.SndBufSize = sndBufSize s.SndBufUsed = e.sndBufUsed s.SndClosed = e.sndClosed s.SndBufInQueue = e.sndBufInQueue @@ -3113,3 +3093,17 @@ func (e *endpoint) Wait() { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// GetTCPSendBufferLimits is used to get send buffer size limits for TCP. +func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.SendBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ba67176b5..19c1dc67a 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -179,14 +179,16 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + sendBufferSize := e.getSendBufferSize() + if sendBufferSize < ss.Min || sendBufferSize > ss.Max { + panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index db9dc4e25..754f6b92a 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4365,9 +4365,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + s, err := ep.SocketOptions().GetSendBufferSize() if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + t.Fatalf("GetSendBufferSize failed: %s", err) } if int(s) != v { @@ -4473,9 +4473,7 @@ func TestMinMaxBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(149, true, tcp.GetTCPSendBufferLimits) checkSendBufferSize(t, ep, 300) @@ -4487,9 +4485,7 @@ func TestMinMaxBufferSizes(t *testing.T) { // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true, tcp.GetTCPSendBufferLimits) // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8544fcb08..ee7e43b97 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -97,9 +97,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. state EndpointState @@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack) e.ops.SetMulticastLoop(true) + e.ops.SetSendBufferSize(32*1024, false /* notify */, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */, tcpip.GetStackSendBufferLimits) } var rs stack.ReceiveBufferSizeOption @@ -632,25 +630,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvBufSizeMax = v e.mu.Unlock() return nil - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) - } - - if v < ss.Min { - v = ss.Min - } - if v > ss.Max { - v = ss.Max - } - - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -811,12 +790,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 13b72dc88..903397f1c 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -91,6 +91,7 @@ func (e *endpoint) Resume(s *stack.Stack) { defer e.mu.Unlock() e.stack = s + e.ops.InitHandler(e, e.stack) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 2ed4f6f9c..d25be0e30 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -548,13 +548,7 @@ TEST_P(RawPacketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack - // matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 955bcee4b..32924466f 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -621,13 +621,7 @@ TEST_P(RawSocketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack - // matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index e557572a7..8eec31a46 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2513,11 +2513,7 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } |