summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go53
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go33
-rw-r--r--pkg/tcpip/transport/icmp/icmp_state_autogen.go35
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go74
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go25
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go35
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go76
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go33
-rw-r--r--pkg/tcpip/transport/raw/raw_state_autogen.go31
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go152
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go68
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go34
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go79
17 files changed, 318 insertions, 424 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 50991c3c0..33ed78f54 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -63,12 +63,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList icmpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -84,6 +83,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
@@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- state: stateInitial,
- uniqueID: s.UniqueID(),
+ waiterQueue: waiterQueue,
+ state: stateInitial,
+ uniqueID: s.UniqueID(),
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
+ var rs tcpip.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
+ }
return ep, nil
}
@@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.rcvMu.Lock()
v := int(e.ttl)
@@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index a3c6db5a8..28a56a2d5 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if e.state != stateBound && e.state != stateConnected {
return
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
index 4175f0f8b..98c1b06a3 100644
--- a/pkg/tcpip/transport/icmp/icmp_state_autogen.go
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -54,7 +54,6 @@ func (e *endpoint) StateFields() []string {
"uniqueID",
"rcvReady",
"rcvList",
- "rcvBufSizeMax",
"rcvBufSize",
"rcvClosed",
"shutdownFlags",
@@ -62,27 +61,27 @@ func (e *endpoint) StateFields() []string {
"ttl",
"owner",
"ops",
+ "frozen",
}
}
// +checklocksignore
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
- var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
- stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.uniqueID)
stateSinkObject.Save(4, &e.rcvReady)
stateSinkObject.Save(5, &e.rcvList)
- stateSinkObject.Save(7, &e.rcvBufSize)
- stateSinkObject.Save(8, &e.rcvClosed)
- stateSinkObject.Save(9, &e.shutdownFlags)
- stateSinkObject.Save(10, &e.state)
- stateSinkObject.Save(11, &e.ttl)
- stateSinkObject.Save(12, &e.owner)
- stateSinkObject.Save(13, &e.ops)
+ stateSinkObject.Save(6, &e.rcvBufSize)
+ stateSinkObject.Save(7, &e.rcvClosed)
+ stateSinkObject.Save(8, &e.shutdownFlags)
+ stateSinkObject.Save(9, &e.state)
+ stateSinkObject.Save(10, &e.ttl)
+ stateSinkObject.Save(11, &e.owner)
+ stateSinkObject.Save(12, &e.ops)
+ stateSinkObject.Save(13, &e.frozen)
}
// +checklocksignore
@@ -93,14 +92,14 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(3, &e.uniqueID)
stateSourceObject.Load(4, &e.rcvReady)
stateSourceObject.Load(5, &e.rcvList)
- stateSourceObject.Load(7, &e.rcvBufSize)
- stateSourceObject.Load(8, &e.rcvClosed)
- stateSourceObject.Load(9, &e.shutdownFlags)
- stateSourceObject.Load(10, &e.state)
- stateSourceObject.Load(11, &e.ttl)
- stateSourceObject.Load(12, &e.owner)
- stateSourceObject.Load(13, &e.ops)
- stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
+ stateSourceObject.Load(6, &e.rcvBufSize)
+ stateSourceObject.Load(7, &e.rcvClosed)
+ stateSourceObject.Load(8, &e.shutdownFlags)
+ stateSourceObject.Load(9, &e.state)
+ stateSourceObject.Load(10, &e.ttl)
+ stateSourceObject.Load(11, &e.owner)
+ stateSourceObject.Load(12, &e.ops)
+ stateSourceObject.Load(13, &e.frozen)
stateSourceObject.AfterLoad(e.afterLoad)
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 52ed9560c..496eca581 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -72,11 +72,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -91,6 +90,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// NewEndpoint returns a new packet endpoint.
@@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
},
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- ep.rcvBufSizeMax = rs.Default
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
@@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := ep.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- ep.rcvMu.Lock()
- ep.rcvBufSizeMax = v
- ep.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
func (ep *endpoint) LastError() tcpip.Error {
@@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
ep.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- v := ep.rcvBufSizeMax
- ep.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
return
}
- if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ rcvBufSize := ep.ops.GetReceiveBufferSize()
+ if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (ep *endpoint) freeze() {
+ ep.mu.Lock()
+ ep.frozen = true
+ ep.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (ep *endpoint) thaw() {
+ ep.mu.Lock()
+ ep.frozen = false
+ ep.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index ece662c0d..5bd860d20 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- ep.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- ep.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+ ep.freeze()
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
+ ep.thaw()
ep.stack = stack.StackFromEnv
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
index f8b82e575..b354c87b1 100644
--- a/pkg/tcpip/transport/packet/packet_state_autogen.go
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -57,7 +57,6 @@ func (ep *endpoint) StateFields() []string {
"waiterQueue",
"cooked",
"rcvList",
- "rcvBufSizeMax",
"rcvBufSize",
"rcvClosed",
"closed",
@@ -65,27 +64,27 @@ func (ep *endpoint) StateFields() []string {
"boundNIC",
"lastError",
"ops",
+ "frozen",
}
}
// +checklocksignore
func (ep *endpoint) StateSave(stateSinkObject state.Sink) {
ep.beforeSave()
- var rcvBufSizeMaxValue int = ep.saveRcvBufSizeMax()
- stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
stateSinkObject.Save(0, &ep.TransportEndpointInfo)
stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &ep.netProto)
stateSinkObject.Save(3, &ep.waiterQueue)
stateSinkObject.Save(4, &ep.cooked)
stateSinkObject.Save(5, &ep.rcvList)
- stateSinkObject.Save(7, &ep.rcvBufSize)
- stateSinkObject.Save(8, &ep.rcvClosed)
- stateSinkObject.Save(9, &ep.closed)
- stateSinkObject.Save(10, &ep.bound)
- stateSinkObject.Save(11, &ep.boundNIC)
- stateSinkObject.Save(12, &ep.lastError)
- stateSinkObject.Save(13, &ep.ops)
+ stateSinkObject.Save(6, &ep.rcvBufSize)
+ stateSinkObject.Save(7, &ep.rcvClosed)
+ stateSinkObject.Save(8, &ep.closed)
+ stateSinkObject.Save(9, &ep.bound)
+ stateSinkObject.Save(10, &ep.boundNIC)
+ stateSinkObject.Save(11, &ep.lastError)
+ stateSinkObject.Save(12, &ep.ops)
+ stateSinkObject.Save(13, &ep.frozen)
}
// +checklocksignore
@@ -96,14 +95,14 @@ func (ep *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(3, &ep.waiterQueue)
stateSourceObject.Load(4, &ep.cooked)
stateSourceObject.Load(5, &ep.rcvList)
- stateSourceObject.Load(7, &ep.rcvBufSize)
- stateSourceObject.Load(8, &ep.rcvClosed)
- stateSourceObject.Load(9, &ep.closed)
- stateSourceObject.Load(10, &ep.bound)
- stateSourceObject.Load(11, &ep.boundNIC)
- stateSourceObject.Load(12, &ep.lastError)
- stateSourceObject.Load(13, &ep.ops)
- stateSourceObject.LoadValue(6, new(int), func(y interface{}) { ep.loadRcvBufSizeMax(y.(int)) })
+ stateSourceObject.Load(6, &ep.rcvBufSize)
+ stateSourceObject.Load(7, &ep.rcvClosed)
+ stateSourceObject.Load(8, &ep.closed)
+ stateSourceObject.Load(9, &ep.bound)
+ stateSourceObject.Load(10, &ep.boundNIC)
+ stateSourceObject.Load(11, &ep.lastError)
+ stateSourceObject.Load(12, &ep.ops)
+ stateSourceObject.Load(13, &ep.frozen)
stateSourceObject.AfterLoad(ep.afterLoad)
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e27a249cd..10453a42a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,7 +26,6 @@
package raw
import (
- "fmt"
"io"
"gvisor.dev/gvisor/pkg/sync"
@@ -69,11 +68,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList rawPacketList
- rcvBufSize int
- rcvBufSizeMax int `state:".(int)"`
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList rawPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -89,6 +87,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- associated: associated,
+ waiterQueue: waiterQueue,
+ associated: associated,
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
// Unassociated endpoints are write-only and users call Write() with IP
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- e.rcvBufSizeMax = 0
+ e.ops.SetReceiveBufferSize(0, false)
e.waiterQueue = nil
return e, nil
}
@@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- e.rcvMu.Lock()
- e.rcvBufSizeMax = v
- e.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
@@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 263ec5146..5d6f2709c 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// If the endpoint is connected, re-connect.
if e.connected {
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
index 3bcfc7c61..2bcd983a2 100644
--- a/pkg/tcpip/transport/raw/raw_state_autogen.go
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -54,33 +54,32 @@ func (e *endpoint) StateFields() []string {
"associated",
"rcvList",
"rcvBufSize",
- "rcvBufSizeMax",
"rcvClosed",
"closed",
"connected",
"bound",
"owner",
"ops",
+ "frozen",
}
}
// +checklocksignore
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
- var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
- stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.associated)
stateSinkObject.Save(4, &e.rcvList)
stateSinkObject.Save(5, &e.rcvBufSize)
- stateSinkObject.Save(7, &e.rcvClosed)
- stateSinkObject.Save(8, &e.closed)
- stateSinkObject.Save(9, &e.connected)
- stateSinkObject.Save(10, &e.bound)
- stateSinkObject.Save(11, &e.owner)
- stateSinkObject.Save(12, &e.ops)
+ stateSinkObject.Save(6, &e.rcvClosed)
+ stateSinkObject.Save(7, &e.closed)
+ stateSinkObject.Save(8, &e.connected)
+ stateSinkObject.Save(9, &e.bound)
+ stateSinkObject.Save(10, &e.owner)
+ stateSinkObject.Save(11, &e.ops)
+ stateSinkObject.Save(12, &e.frozen)
}
// +checklocksignore
@@ -91,13 +90,13 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(3, &e.associated)
stateSourceObject.Load(4, &e.rcvList)
stateSourceObject.Load(5, &e.rcvBufSize)
- stateSourceObject.Load(7, &e.rcvClosed)
- stateSourceObject.Load(8, &e.closed)
- stateSourceObject.Load(9, &e.connected)
- stateSourceObject.Load(10, &e.bound)
- stateSourceObject.Load(11, &e.owner)
- stateSourceObject.Load(12, &e.ops)
- stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
+ stateSourceObject.Load(6, &e.rcvClosed)
+ stateSourceObject.Load(7, &e.closed)
+ stateSourceObject.Load(8, &e.connected)
+ stateSourceObject.Load(9, &e.bound)
+ stateSourceObject.Load(10, &e.owner)
+ stateSourceObject.Load(11, &e.ops)
+ stateSourceObject.Load(12, &e.frozen)
stateSourceObject.AfterLoad(e.afterLoad)
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 664cb9420..d4bd4e80e 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -219,7 +219,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
n.boundNICID = s.nicID
n.route = route
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
- n.rcvQueueInfo.RcvBufSize = int(l.rcvWnd)
+ n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 884332828..f25dc781a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -822,11 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
- e.rcvQueueInfo.RcvBufSize = DefaultReceiveBufferSize
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetQuickAck(true)
e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */)
+ e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -835,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- e.rcvQueueInfo.RcvBufSize = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
var cs tcpip.CongestionControlOption
@@ -1228,11 +1228,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// We do not adjust downwards as that can cause the receiver to
// reject valid data that might already be in flight as the
// acceptable window will shrink.
- if rcvWnd > e.rcvQueueInfo.RcvBufSize {
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvQueueInfo.RcvBufSize = rcvWnd
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ rcvBufSize := int(e.ops.GetReceiveBufferSize())
+ if rcvWnd > rcvBufSize {
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd))
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
@@ -1424,7 +1425,7 @@ func (e *endpoint) commitRead(done int) *segment {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
@@ -1556,9 +1557,9 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvQueueMu must be held.
-func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
- wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
- maxWindow := wndFromSpace(e.rcvQueueInfo.RcvBufSize)
+func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ maxWindow := wndFromSpace(rcvBufSize)
wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed
// We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
@@ -1580,7 +1581,7 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu.
func (e *endpoint) selectWindow() (wnd seqnum.Size) {
e.rcvQueueInfo.rcvQueueMu.Lock()
- wnd = e.selectWindowLocked()
+ wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize()))
e.rcvQueueInfo.rcvQueueMu.Unlock()
return wnd
}
@@ -1600,8 +1601,8 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) {
// otherwise.
//
// Precondition: e.mu and e.rcvQueueMu must be held.
-func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := int(e.selectWindowLocked())
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) {
+ newAvail := int(e.selectWindowLocked(rcvBufSize))
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -1610,7 +1611,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
// rcvBufFraction is the inverse of the fraction of receive buffer size that
// is used to decide if the available buffer space is now above it.
const rcvBufFraction = 2
- if wndThreshold := wndFromSpace(e.rcvQueueInfo.RcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold {
threshold = wndThreshold
}
switch {
@@ -1661,6 +1662,37 @@ func (e *endpoint) getSendBufferSize() int {
return int(e.ops.GetSendBufferSize())
}
+// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize.
+func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
+ e.LockUser()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.RcvWndScale
+ }
+ if rcvBufSz>>scale == 0 {
+ rcvBufSz = 1 << scale
+ }
+
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz)))
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz)))
+ e.rcvQueueInfo.RcvAutoParams.Disabled = true
+
+ // Immediately send an ACK to uncork the sender silly window
+ // syndrome prevetion, when our available space grows above aMSS
+ // or half receive buffer, whichever smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+ e.UnlockUser()
+ return rcvBufSz
+}
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -1704,56 +1736,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs tcpip.TCPReceiveBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
- panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
- }
-
- if v > rs.Max {
- v = rs.Max
- }
-
- if v < math.MaxInt32/SegOverheadFactor {
- v *= SegOverheadFactor
- if v < rs.Min {
- v = rs.Min
- }
- } else {
- v = math.MaxInt32
- }
-
- e.LockUser()
- e.rcvQueueInfo.rcvQueueMu.Lock()
-
- // Make sure the receive buffer size allows us to send a
- // non-zero window size.
- scale := uint8(0)
- if e.rcv != nil {
- scale = e.rcv.RcvWndScale
- }
- if v>>scale == 0 {
- v = 1 << scale
- }
-
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvQueueInfo.RcvBufSize = v
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
-
- e.rcvQueueInfo.RcvAutoParams.Disabled = true
-
- // Immediately send an ACK to uncork the sender silly window
- // syndrome prevetion, when our available space grows above aMSS
- // or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
- }
-
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- e.UnlockUser()
-
case tcpip.TTLOption:
e.LockUser()
e.ttl = uint8(v)
@@ -1939,12 +1921,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
- case tcpip.ReceiveBufferSizeOption:
- e.rcvQueueInfo.rcvQueueMu.Lock()
- v := e.rcvQueueInfo.RcvBufSize
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.LockUser()
v := int(e.ttl)
@@ -2780,15 +2756,15 @@ func (e *endpoint) readyToRead(s *segment) {
// receiveBufferAvailableLocked calculates how many bytes are still available
// in the receive buffer.
// rcvQueueMu must be held when this function is called.
-func (e *endpoint) receiveBufferAvailableLocked() int {
+func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
memUsed := e.receiveMemUsed()
- if memUsed >= e.rcvQueueInfo.RcvBufSize {
+ if memUsed >= rcvBufSize {
return 0
}
- return e.rcvQueueInfo.RcvBufSize - memUsed
+ return rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
@@ -2796,7 +2772,7 @@ func (e *endpoint) receiveBufferAvailableLocked() int {
// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvQueueInfo.rcvQueueMu.Lock()
- available := e.receiveBufferAvailableLocked()
+ available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize()))
e.rcvQueueInfo.rcvQueueMu.Unlock()
return available
}
@@ -2809,14 +2785,6 @@ func (e *endpoint) receiveBufferUsed() int {
return used
}
-// receiveBufferSize returns the current size of the receive buffer.
-func (e *endpoint) receiveBufferSize() int {
- e.rcvQueueInfo.rcvQueueMu.Lock()
- size := e.rcvQueueInfo.RcvBufSize
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- return size
-}
-
// receiveMemUsed returns the total memory in use by segments held by this
// endpoint.
func (e *endpoint) receiveMemUsed() int {
@@ -2845,7 +2813,7 @@ func (e *endpoint) maxReceiveBufferSize() int {
// receiveBuffer otherwise we use the max permissible receive buffer size to
// compute the scale.
func (e *endpoint) rcvWndScaleForHandshake() int {
- bufSizeForScale := e.receiveBufferSize()
+ bufSizeForScale := e.ops.GetReceiveBufferSize()
e.rcvQueueInfo.rcvQueueMu.Lock()
autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled
@@ -3074,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool {
e.lastOutOfWindowAckTime = now
return true
}
+
+// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP.
+func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption {
+ var ss tcpip.TCPReceiveBufferSizeRangeOption
+ if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err))
+ }
+
+ return tcpip.ReceiveBufferSizeOption{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 034eacd72..6e9777fe4 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -165,7 +165,7 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
@@ -180,8 +180,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if e.rcvQueueInfo.RcvBufSize < rs.Min || e.rcvQueueInfo.RcvBufSize > rs.Max {
- panic(fmt.Sprintf("endpoint.rcvQueueInfo.RcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvQueueInfo.RcvBufSize, rs.Min, rs.Max))
+ if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) {
+ panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max))
}
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index fc11b4ba9..ee2c08cd6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -466,7 +466,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if r.ep.receiveBufferAvailable() > 0 && r.PendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 54545a1b1..d0d1b0b8a 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool {
func (q *segmentQueue) enqueue(s *segment) bool {
// q.ep.receiveBufferParams() must be called without holding q.mu to
// avoid lock order inversion.
- bufSz := q.ep.receiveBufferSize()
+ bufSz := q.ep.ops.GetReceiveBufferSize()
used := q.ep.receiveMemUsed()
q.mu.Lock()
// Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
// is currently full).
- allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+ allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen
if allow {
q.list.PushBack(s)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f26c7ca10..c9f2f3efc 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "fmt"
"io"
"sync/atomic"
@@ -89,12 +88,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList udpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -144,6 +142,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// +stateify savable
@@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- rcvBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
return e
@@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
e.mu.Lock()
e.sendTOS = uint8(v)
e.mu.Unlock()
-
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
- }
-
- if v < rs.Min {
- v = rs.Min
- }
- if v > rs.Max {
- v = rs.Max
- }
-
- e.mu.Lock()
- e.rcvBufSizeMax = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.mu.Lock()
v := int(e.ttl)
@@ -1310,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -1444,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 21a6aa460..4aba68b21 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
+
e.mu.Lock()
defer e.mu.Unlock()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index 5d7859a9e..092daa0b8 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -63,7 +63,6 @@ func (e *endpoint) StateFields() []string {
"uniqueID",
"rcvReady",
"rcvList",
- "rcvBufSizeMax",
"rcvBufSize",
"rcvClosed",
"state",
@@ -82,38 +81,38 @@ func (e *endpoint) StateFields() []string {
"effectiveNetProtos",
"owner",
"ops",
+ "frozen",
}
}
// +checklocksignore
func (e *endpoint) StateSave(stateSinkObject state.Sink) {
e.beforeSave()
- var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
- stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
stateSinkObject.Save(3, &e.uniqueID)
stateSinkObject.Save(4, &e.rcvReady)
stateSinkObject.Save(5, &e.rcvList)
- stateSinkObject.Save(7, &e.rcvBufSize)
- stateSinkObject.Save(8, &e.rcvClosed)
- stateSinkObject.Save(9, &e.state)
- stateSinkObject.Save(10, &e.dstPort)
- stateSinkObject.Save(11, &e.ttl)
- stateSinkObject.Save(12, &e.multicastTTL)
- stateSinkObject.Save(13, &e.multicastAddr)
- stateSinkObject.Save(14, &e.multicastNICID)
- stateSinkObject.Save(15, &e.portFlags)
- stateSinkObject.Save(16, &e.lastError)
- stateSinkObject.Save(17, &e.boundBindToDevice)
- stateSinkObject.Save(18, &e.boundPortFlags)
- stateSinkObject.Save(19, &e.sendTOS)
- stateSinkObject.Save(20, &e.shutdownFlags)
- stateSinkObject.Save(21, &e.multicastMemberships)
- stateSinkObject.Save(22, &e.effectiveNetProtos)
- stateSinkObject.Save(23, &e.owner)
- stateSinkObject.Save(24, &e.ops)
+ stateSinkObject.Save(6, &e.rcvBufSize)
+ stateSinkObject.Save(7, &e.rcvClosed)
+ stateSinkObject.Save(8, &e.state)
+ stateSinkObject.Save(9, &e.dstPort)
+ stateSinkObject.Save(10, &e.ttl)
+ stateSinkObject.Save(11, &e.multicastTTL)
+ stateSinkObject.Save(12, &e.multicastAddr)
+ stateSinkObject.Save(13, &e.multicastNICID)
+ stateSinkObject.Save(14, &e.portFlags)
+ stateSinkObject.Save(15, &e.lastError)
+ stateSinkObject.Save(16, &e.boundBindToDevice)
+ stateSinkObject.Save(17, &e.boundPortFlags)
+ stateSinkObject.Save(18, &e.sendTOS)
+ stateSinkObject.Save(19, &e.shutdownFlags)
+ stateSinkObject.Save(20, &e.multicastMemberships)
+ stateSinkObject.Save(21, &e.effectiveNetProtos)
+ stateSinkObject.Save(22, &e.owner)
+ stateSinkObject.Save(23, &e.ops)
+ stateSinkObject.Save(24, &e.frozen)
}
// +checklocksignore
@@ -124,25 +123,25 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(3, &e.uniqueID)
stateSourceObject.Load(4, &e.rcvReady)
stateSourceObject.Load(5, &e.rcvList)
- stateSourceObject.Load(7, &e.rcvBufSize)
- stateSourceObject.Load(8, &e.rcvClosed)
- stateSourceObject.Load(9, &e.state)
- stateSourceObject.Load(10, &e.dstPort)
- stateSourceObject.Load(11, &e.ttl)
- stateSourceObject.Load(12, &e.multicastTTL)
- stateSourceObject.Load(13, &e.multicastAddr)
- stateSourceObject.Load(14, &e.multicastNICID)
- stateSourceObject.Load(15, &e.portFlags)
- stateSourceObject.Load(16, &e.lastError)
- stateSourceObject.Load(17, &e.boundBindToDevice)
- stateSourceObject.Load(18, &e.boundPortFlags)
- stateSourceObject.Load(19, &e.sendTOS)
- stateSourceObject.Load(20, &e.shutdownFlags)
- stateSourceObject.Load(21, &e.multicastMemberships)
- stateSourceObject.Load(22, &e.effectiveNetProtos)
- stateSourceObject.Load(23, &e.owner)
- stateSourceObject.Load(24, &e.ops)
- stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
+ stateSourceObject.Load(6, &e.rcvBufSize)
+ stateSourceObject.Load(7, &e.rcvClosed)
+ stateSourceObject.Load(8, &e.state)
+ stateSourceObject.Load(9, &e.dstPort)
+ stateSourceObject.Load(10, &e.ttl)
+ stateSourceObject.Load(11, &e.multicastTTL)
+ stateSourceObject.Load(12, &e.multicastAddr)
+ stateSourceObject.Load(13, &e.multicastNICID)
+ stateSourceObject.Load(14, &e.portFlags)
+ stateSourceObject.Load(15, &e.lastError)
+ stateSourceObject.Load(16, &e.boundBindToDevice)
+ stateSourceObject.Load(17, &e.boundPortFlags)
+ stateSourceObject.Load(18, &e.sendTOS)
+ stateSourceObject.Load(19, &e.shutdownFlags)
+ stateSourceObject.Load(20, &e.multicastMemberships)
+ stateSourceObject.Load(21, &e.effectiveNetProtos)
+ stateSourceObject.Load(22, &e.owner)
+ stateSourceObject.Load(23, &e.ops)
+ stateSourceObject.Load(24, &e.frozen)
stateSourceObject.AfterLoad(e.afterLoad)
}