diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint_state.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint_state.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint_state.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 152 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_queue.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 49 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 34 |
15 files changed, 242 insertions, 373 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 50991c3c0..33ed78f54 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -63,12 +63,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList icmpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -84,6 +83,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { @@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - state: stateInitial, - uniqueID: s.UniqueID(), + waiterQueue: waiterQueue, + state: stateInitial, + uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } + var rs tcpip.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) + } return ep, nil } @@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.rcvMu.Lock() v := int(e.ttl) @@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index a3c6db5a8..28a56a2d5 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 52ed9560c..496eca581 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -72,11 +72,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -91,6 +90,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a new packet endpoint. @@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - ep.rcvBufSizeMax = rs.Default + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { @@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := ep.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - ep.rcvMu.Lock() - ep.rcvBufSizeMax = v - ep.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } func (ep *endpoint) LastError() tcpip.Error { @@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, return } - if ep.rcvBufSize >= ep.rcvBufSizeMax { + rcvBufSize := ep.ops.GetReceiveBufferSize() + if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (ep *endpoint) freeze() { + ep.mu.Lock() + ep.frozen = true + ep.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (ep *endpoint) thaw() { + ep.mu.Lock() + ep.frozen = false + ep.mu.Unlock() +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index ece662c0d..5bd860d20 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - ep.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - ep.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max + ep.freeze() } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { + ep.thaw() ep.stack = stack.StackFromEnv - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..10453a42a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. if e.connected { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 664cb9420..d4bd4e80e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -219,7 +219,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvQueueInfo.RcvBufSize = int(l.rcvWnd) + n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 884332828..f25dc781a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -822,11 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.rcvQueueInfo.RcvBufSize = DefaultReceiveBufferSize - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) + e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -835,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvQueueInfo.RcvBufSize = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } var cs tcpip.CongestionControlOption @@ -1228,11 +1228,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvQueueInfo.RcvBufSize { - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvQueueInfo.RcvBufSize = rcvWnd - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + rcvBufSize := int(e.ops.GetReceiveBufferSize()) + if rcvWnd > rcvBufSize { + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd)) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } @@ -1424,7 +1425,7 @@ func (e *endpoint) commitRead(done int) *segment { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } @@ -1556,9 +1557,9 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. // Precondition: e.mu and e.rcvQueueMu must be held. -func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { - wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvQueueInfo.RcvBufSize) +func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + maxWindow := wndFromSpace(rcvBufSize) wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in @@ -1580,7 +1581,7 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { // selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. func (e *endpoint) selectWindow() (wnd seqnum.Size) { e.rcvQueueInfo.rcvQueueMu.Lock() - wnd = e.selectWindowLocked() + wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize())) e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1600,8 +1601,8 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // otherwise. // // Precondition: e.mu and e.rcvQueueMu must be held. -func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := int(e.selectWindowLocked()) +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) { + newAvail := int(e.selectWindowLocked(rcvBufSize)) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -1610,7 +1611,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvQueueInfo.RcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1661,6 +1662,37 @@ func (e *endpoint) getSendBufferSize() int { return int(e.ops.GetSendBufferSize()) } +// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize. +func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { + e.LockUser() + e.rcvQueueInfo.rcvQueueMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.RcvWndScale + } + if rcvBufSz>>scale == 0 { + rcvBufSz = 1 << scale + } + + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz))) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz))) + e.rcvQueueInfo.RcvAutoParams.Disabled = true + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + e.rcvQueueInfo.rcvQueueMu.Unlock() + e.UnlockUser() + return rcvBufSz +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1704,56 +1736,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return &tcpip.ErrNotSupported{} } - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) - } - - if v > rs.Max { - v = rs.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < rs.Min { - v = rs.Min - } - } else { - v = math.MaxInt32 - } - - e.LockUser() - e.rcvQueueInfo.rcvQueueMu.Lock() - - // Make sure the receive buffer size allows us to send a - // non-zero window size. - scale := uint8(0) - if e.rcv != nil { - scale = e.rcv.RcvWndScale - } - if v>>scale == 0 { - v = 1 << scale - } - - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvQueueInfo.RcvBufSize = v - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - - e.rcvQueueInfo.RcvAutoParams.Disabled = true - - // Immediately send an ACK to uncork the sender silly window - // syndrome prevetion, when our available space grows above aMSS - // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) - } - - e.rcvQueueInfo.rcvQueueMu.Unlock() - e.UnlockUser() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1939,12 +1921,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.ReceiveBufferSizeOption: - e.rcvQueueInfo.rcvQueueMu.Lock() - v := e.rcvQueueInfo.RcvBufSize - e.rcvQueueInfo.rcvQueueMu.Unlock() - return v, nil - case tcpip.TTLOption: e.LockUser() v := int(e.ttl) @@ -2780,15 +2756,15 @@ func (e *endpoint) readyToRead(s *segment) { // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. // rcvQueueMu must be held when this function is called. -func (e *endpoint) receiveBufferAvailableLocked() int { +func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvQueueInfo.RcvBufSize { + if memUsed >= rcvBufSize { return 0 } - return e.rcvQueueInfo.RcvBufSize - memUsed + return rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the @@ -2796,7 +2772,7 @@ func (e *endpoint) receiveBufferAvailableLocked() int { // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvQueueInfo.rcvQueueMu.Lock() - available := e.receiveBufferAvailableLocked() + available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize())) e.rcvQueueInfo.rcvQueueMu.Unlock() return available } @@ -2809,14 +2785,6 @@ func (e *endpoint) receiveBufferUsed() int { return used } -// receiveBufferSize returns the current size of the receive buffer. -func (e *endpoint) receiveBufferSize() int { - e.rcvQueueInfo.rcvQueueMu.Lock() - size := e.rcvQueueInfo.RcvBufSize - e.rcvQueueInfo.rcvQueueMu.Unlock() - return size -} - // receiveMemUsed returns the total memory in use by segments held by this // endpoint. func (e *endpoint) receiveMemUsed() int { @@ -2845,7 +2813,7 @@ func (e *endpoint) maxReceiveBufferSize() int { // receiveBuffer otherwise we use the max permissible receive buffer size to // compute the scale. func (e *endpoint) rcvWndScaleForHandshake() int { - bufSizeForScale := e.receiveBufferSize() + bufSizeForScale := e.ops.GetReceiveBufferSize() e.rcvQueueInfo.rcvQueueMu.Lock() autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled @@ -3074,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool { e.lastOutOfWindowAckTime = now return true } + +// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP. +func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + var ss tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.ReceiveBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 034eacd72..6e9777fe4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -165,7 +165,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { @@ -180,8 +180,8 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvQueueInfo.RcvBufSize < rs.Min || e.rcvQueueInfo.RcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvQueueInfo.RcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvQueueInfo.RcvBufSize, rs.Min, rs.Max)) + if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { + panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) } } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index fc11b4ba9..ee2c08cd6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -466,7 +466,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if r.ep.receiveBufferAvailable() > 0 && r.PendingBufUsed < r.ep.receiveBufferSize()>>2 { + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { r.ep.rcvQueueInfo.rcvQueueMu.Lock() r.PendingBufUsed += s.segMemSize() r.ep.rcvQueueInfo.rcvQueueMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 54545a1b1..d0d1b0b8a 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool { func (q *segmentQueue) enqueue(s *segment) bool { // q.ep.receiveBufferParams() must be called without holding q.mu to // avoid lock order inversion. - bufSz := q.ep.receiveBufferSize() + bufSz := q.ep.ops.GetReceiveBufferSize() used := q.ep.receiveMemUsed() q.mu.Lock() // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue // is currently full). - allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen if allow { q.list.PushBack(s) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 74e11ab84..9f29a48fb 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -930,10 +930,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { } // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) - } + rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} @@ -2080,9 +2077,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality @@ -2202,9 +2197,7 @@ func TestNoWindowShrinking(t *testing.T) { initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) data := generateRandomPayload(t, rcvBufSize) // Send a payload of half the size of rcvBufSize. @@ -2460,9 +2453,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2534,9 +2525,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3129,9 +3118,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3270,9 +3257,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 3 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) @@ -4496,11 +4481,7 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) - } - + s := ep.SocketOptions().GetReceiveBufferSize() if int(s) != v { t.Fatalf("got receive buffer size = %d, want = %d", s, v) } @@ -4606,10 +4587,7 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min/2. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(99, true) checkRecvBufferSize(t, ep, 200) ep.SocketOptions().SetSendBufferSize(149, true) @@ -4617,15 +4595,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } @@ -7633,8 +7607,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e73f90bb0..7578d64ec 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) { } if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f26c7ca10..c9f2f3efc 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "fmt" "io" "sync/atomic" @@ -89,12 +88,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList udpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -144,6 +142,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // +stateify savable @@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } return e @@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.mu.Lock() e.sendTOS = uint8(v) e.mu.Unlock() - - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) - } - - if v < rs.Min { - v = rs.Min - } - if v > rs.Max { - v = rs.Max - } - - e.mu.Lock() - e.rcvBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.mu.Lock() v := int(e.ttl) @@ -1310,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -1444,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 21a6aa460..4aba68b21 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() + e.mu.Lock() defer e.mu.Unlock() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { |