diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint_state.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/icmp_state_autogen.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint_state.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/packet_state_autogen.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint_state.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/raw_state_autogen.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 152 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_queue.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 79 |
17 files changed, 318 insertions, 424 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 50991c3c0..33ed78f54 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -63,12 +63,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList icmpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -84,6 +83,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { @@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - state: stateInitial, - uniqueID: s.UniqueID(), + waiterQueue: waiterQueue, + state: stateInitial, + uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } + var rs tcpip.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) + } return ep, nil } @@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.rcvMu.Lock() v := int(e.ttl) @@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index a3c6db5a8..28a56a2d5 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go index 4175f0f8b..98c1b06a3 100644 --- a/pkg/tcpip/transport/icmp/icmp_state_autogen.go +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -54,7 +54,6 @@ func (e *endpoint) StateFields() []string { "uniqueID", "rcvReady", "rcvList", - "rcvBufSizeMax", "rcvBufSize", "rcvClosed", "shutdownFlags", @@ -62,27 +61,27 @@ func (e *endpoint) StateFields() []string { "ttl", "owner", "ops", + "frozen", } } // +checklocksignore func (e *endpoint) StateSave(stateSinkObject state.Sink) { e.beforeSave() - var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() - stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) stateSinkObject.Save(3, &e.uniqueID) stateSinkObject.Save(4, &e.rcvReady) stateSinkObject.Save(5, &e.rcvList) - stateSinkObject.Save(7, &e.rcvBufSize) - stateSinkObject.Save(8, &e.rcvClosed) - stateSinkObject.Save(9, &e.shutdownFlags) - stateSinkObject.Save(10, &e.state) - stateSinkObject.Save(11, &e.ttl) - stateSinkObject.Save(12, &e.owner) - stateSinkObject.Save(13, &e.ops) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.shutdownFlags) + stateSinkObject.Save(9, &e.state) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.owner) + stateSinkObject.Save(12, &e.ops) + stateSinkObject.Save(13, &e.frozen) } // +checklocksignore @@ -93,14 +92,14 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(3, &e.uniqueID) stateSourceObject.Load(4, &e.rcvReady) stateSourceObject.Load(5, &e.rcvList) - stateSourceObject.Load(7, &e.rcvBufSize) - stateSourceObject.Load(8, &e.rcvClosed) - stateSourceObject.Load(9, &e.shutdownFlags) - stateSourceObject.Load(10, &e.state) - stateSourceObject.Load(11, &e.ttl) - stateSourceObject.Load(12, &e.owner) - stateSourceObject.Load(13, &e.ops) - stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.shutdownFlags) + stateSourceObject.Load(9, &e.state) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.owner) + stateSourceObject.Load(12, &e.ops) + stateSourceObject.Load(13, &e.frozen) stateSourceObject.AfterLoad(e.afterLoad) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 52ed9560c..496eca581 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -72,11 +72,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -91,6 +90,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a new packet endpoint. @@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - ep.rcvBufSizeMax = rs.Default + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { @@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := ep.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - ep.rcvMu.Lock() - ep.rcvBufSizeMax = v - ep.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } func (ep *endpoint) LastError() tcpip.Error { @@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, return } - if ep.rcvBufSize >= ep.rcvBufSizeMax { + rcvBufSize := ep.ops.GetReceiveBufferSize() + if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (ep *endpoint) freeze() { + ep.mu.Lock() + ep.frozen = true + ep.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (ep *endpoint) thaw() { + ep.mu.Lock() + ep.frozen = false + ep.mu.Unlock() +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index ece662c0d..5bd860d20 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - ep.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - ep.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max + ep.freeze() } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { + ep.thaw() ep.stack = stack.StackFromEnv - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go index f8b82e575..b354c87b1 100644 --- a/pkg/tcpip/transport/packet/packet_state_autogen.go +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -57,7 +57,6 @@ func (ep *endpoint) StateFields() []string { "waiterQueue", "cooked", "rcvList", - "rcvBufSizeMax", "rcvBufSize", "rcvClosed", "closed", @@ -65,27 +64,27 @@ func (ep *endpoint) StateFields() []string { "boundNIC", "lastError", "ops", + "frozen", } } // +checklocksignore func (ep *endpoint) StateSave(stateSinkObject state.Sink) { ep.beforeSave() - var rcvBufSizeMaxValue int = ep.saveRcvBufSizeMax() - stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) stateSinkObject.Save(0, &ep.TransportEndpointInfo) stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &ep.netProto) stateSinkObject.Save(3, &ep.waiterQueue) stateSinkObject.Save(4, &ep.cooked) stateSinkObject.Save(5, &ep.rcvList) - stateSinkObject.Save(7, &ep.rcvBufSize) - stateSinkObject.Save(8, &ep.rcvClosed) - stateSinkObject.Save(9, &ep.closed) - stateSinkObject.Save(10, &ep.bound) - stateSinkObject.Save(11, &ep.boundNIC) - stateSinkObject.Save(12, &ep.lastError) - stateSinkObject.Save(13, &ep.ops) + stateSinkObject.Save(6, &ep.rcvBufSize) + stateSinkObject.Save(7, &ep.rcvClosed) + stateSinkObject.Save(8, &ep.closed) + stateSinkObject.Save(9, &ep.bound) + stateSinkObject.Save(10, &ep.boundNIC) + stateSinkObject.Save(11, &ep.lastError) + stateSinkObject.Save(12, &ep.ops) + stateSinkObject.Save(13, &ep.frozen) } // +checklocksignore @@ -96,14 +95,14 @@ func (ep *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(3, &ep.waiterQueue) stateSourceObject.Load(4, &ep.cooked) stateSourceObject.Load(5, &ep.rcvList) - stateSourceObject.Load(7, &ep.rcvBufSize) - stateSourceObject.Load(8, &ep.rcvClosed) - stateSourceObject.Load(9, &ep.closed) - stateSourceObject.Load(10, &ep.bound) - stateSourceObject.Load(11, &ep.boundNIC) - stateSourceObject.Load(12, &ep.lastError) - stateSourceObject.Load(13, &ep.ops) - stateSourceObject.LoadValue(6, new(int), func(y interface{}) { ep.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.Load(6, &ep.rcvBufSize) + stateSourceObject.Load(7, &ep.rcvClosed) + stateSourceObject.Load(8, &ep.closed) + stateSourceObject.Load(9, &ep.bound) + stateSourceObject.Load(10, &ep.boundNIC) + stateSourceObject.Load(11, &ep.lastError) + stateSourceObject.Load(12, &ep.ops) + stateSourceObject.Load(13, &ep.frozen) stateSourceObject.AfterLoad(ep.afterLoad) } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..10453a42a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. if e.connected { diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go index 3bcfc7c61..2bcd983a2 100644 --- a/pkg/tcpip/transport/raw/raw_state_autogen.go +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -54,33 +54,32 @@ func (e *endpoint) StateFields() []string { "associated", "rcvList", "rcvBufSize", - "rcvBufSizeMax", "rcvClosed", "closed", "connected", "bound", "owner", "ops", + "frozen", } } // +checklocksignore func (e *endpoint) StateSave(stateSinkObject state.Sink) { e.beforeSave() - var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() - stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) stateSinkObject.Save(3, &e.associated) stateSinkObject.Save(4, &e.rcvList) stateSinkObject.Save(5, &e.rcvBufSize) - stateSinkObject.Save(7, &e.rcvClosed) - stateSinkObject.Save(8, &e.closed) - stateSinkObject.Save(9, &e.connected) - stateSinkObject.Save(10, &e.bound) - stateSinkObject.Save(11, &e.owner) - stateSinkObject.Save(12, &e.ops) + stateSinkObject.Save(6, &e.rcvClosed) + stateSinkObject.Save(7, &e.closed) + stateSinkObject.Save(8, &e.connected) + stateSinkObject.Save(9, &e.bound) + stateSinkObject.Save(10, &e.owner) + stateSinkObject.Save(11, &e.ops) + stateSinkObject.Save(12, &e.frozen) } // +checklocksignore @@ -91,13 +90,13 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(3, &e.associated) stateSourceObject.Load(4, &e.rcvList) stateSourceObject.Load(5, &e.rcvBufSize) - stateSourceObject.Load(7, &e.rcvClosed) - stateSourceObject.Load(8, &e.closed) - stateSourceObject.Load(9, &e.connected) - stateSourceObject.Load(10, &e.bound) - stateSourceObject.Load(11, &e.owner) - stateSourceObject.Load(12, &e.ops) - stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.Load(6, &e.rcvClosed) + stateSourceObject.Load(7, &e.closed) + stateSourceObject.Load(8, &e.connected) + stateSourceObject.Load(9, &e.bound) + stateSourceObject.Load(10, &e.owner) + stateSourceObject.Load(11, &e.ops) + stateSourceObject.Load(12, &e.frozen) stateSourceObject.AfterLoad(e.afterLoad) } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 664cb9420..d4bd4e80e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -219,7 +219,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvQueueInfo.RcvBufSize = int(l.rcvWnd) + n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 884332828..f25dc781a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -822,11 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.rcvQueueInfo.RcvBufSize = DefaultReceiveBufferSize - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) + e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -835,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvQueueInfo.RcvBufSize = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } var cs tcpip.CongestionControlOption @@ -1228,11 +1228,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvQueueInfo.RcvBufSize { - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvQueueInfo.RcvBufSize = rcvWnd - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + rcvBufSize := int(e.ops.GetReceiveBufferSize()) + if rcvWnd > rcvBufSize { + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd)) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } @@ -1424,7 +1425,7 @@ func (e *endpoint) commitRead(done int) *segment { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } @@ -1556,9 +1557,9 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. // Precondition: e.mu and e.rcvQueueMu must be held. -func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { - wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvQueueInfo.RcvBufSize) +func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + maxWindow := wndFromSpace(rcvBufSize) wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in @@ -1580,7 +1581,7 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { // selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. func (e *endpoint) selectWindow() (wnd seqnum.Size) { e.rcvQueueInfo.rcvQueueMu.Lock() - wnd = e.selectWindowLocked() + wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize())) e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1600,8 +1601,8 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // otherwise. // // Precondition: e.mu and e.rcvQueueMu must be held. -func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := int(e.selectWindowLocked()) +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) { + newAvail := int(e.selectWindowLocked(rcvBufSize)) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -1610,7 +1611,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvQueueInfo.RcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1661,6 +1662,37 @@ func (e *endpoint) getSendBufferSize() int { return int(e.ops.GetSendBufferSize()) } +// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize. +func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { + e.LockUser() + e.rcvQueueInfo.rcvQueueMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.RcvWndScale + } + if rcvBufSz>>scale == 0 { + rcvBufSz = 1 << scale + } + + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz))) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz))) + e.rcvQueueInfo.RcvAutoParams.Disabled = true + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + e.rcvQueueInfo.rcvQueueMu.Unlock() + e.UnlockUser() + return rcvBufSz +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1704,56 +1736,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return &tcpip.ErrNotSupported{} } - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) - } - - if v > rs.Max { - v = rs.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < rs.Min { - v = rs.Min - } - } else { - v = math.MaxInt32 - } - - e.LockUser() - e.rcvQueueInfo.rcvQueueMu.Lock() - - // Make sure the receive buffer size allows us to send a - // non-zero window size. - scale := uint8(0) - if e.rcv != nil { - scale = e.rcv.RcvWndScale - } - if v>>scale == 0 { - v = 1 << scale - } - - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvQueueInfo.RcvBufSize = v - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - - e.rcvQueueInfo.RcvAutoParams.Disabled = true - - // Immediately send an ACK to uncork the sender silly window - // syndrome prevetion, when our available space grows above aMSS - // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) - } - - e.rcvQueueInfo.rcvQueueMu.Unlock() - e.UnlockUser() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1939,12 +1921,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.ReceiveBufferSizeOption: - e.rcvQueueInfo.rcvQueueMu.Lock() - v := e.rcvQueueInfo.RcvBufSize - e.rcvQueueInfo.rcvQueueMu.Unlock() - return v, nil - case tcpip.TTLOption: e.LockUser() v := int(e.ttl) @@ -2780,15 +2756,15 @@ func (e *endpoint) readyToRead(s *segment) { // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. // rcvQueueMu must be held when this function is called. -func (e *endpoint) receiveBufferAvailableLocked() int { +func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvQueueInfo.RcvBufSize { + if memUsed >= rcvBufSize { return 0 } - return e.rcvQueueInfo.RcvBufSize - memUsed + return rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the @@ -2796,7 +2772,7 @@ func (e *endpoint) receiveBufferAvailableLocked() int { // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvQueueInfo.rcvQueueMu.Lock() - available := e.receiveBufferAvailableLocked() + available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize())) e.rcvQueueInfo.rcvQueueMu.Unlock() return available } @@ -2809,14 +2785,6 @@ func (e *endpoint) receiveBufferUsed() int { return used } -// receiveBufferSize returns the current size of the receive buffer. -func (e *endpoint) receiveBufferSize() int { - e.rcvQueueInfo.rcvQueueMu.Lock() - size := e.rcvQueueInfo.RcvBufSize - e.rcvQueueInfo.rcvQueueMu.Unlock() - return size -} - // receiveMemUsed returns the total memory in use by segments held by this // endpoint. func (e *endpoint) receiveMemUsed() int { @@ -2845,7 +2813,7 @@ func (e *endpoint) maxReceiveBufferSize() int { // receiveBuffer otherwise we use the max permissible receive buffer size to // compute the scale. func (e *endpoint) rcvWndScaleForHandshake() int { - bufSizeForScale := e.receiveBufferSize() + bufSizeForScale := e.ops.GetReceiveBufferSize() e.rcvQueueInfo.rcvQueueMu.Lock() autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled @@ -3074,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool { e.lastOutOfWindowAckTime = now return true } + +// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP. +func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + var ss tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.ReceiveBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 034eacd72..6e9777fe4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -165,7 +165,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { @@ -180,8 +180,8 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvQueueInfo.RcvBufSize < rs.Min || e.rcvQueueInfo.RcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvQueueInfo.RcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvQueueInfo.RcvBufSize, rs.Min, rs.Max)) + if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { + panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) } } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index fc11b4ba9..ee2c08cd6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -466,7 +466,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if r.ep.receiveBufferAvailable() > 0 && r.PendingBufUsed < r.ep.receiveBufferSize()>>2 { + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { r.ep.rcvQueueInfo.rcvQueueMu.Lock() r.PendingBufUsed += s.segMemSize() r.ep.rcvQueueInfo.rcvQueueMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 54545a1b1..d0d1b0b8a 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool { func (q *segmentQueue) enqueue(s *segment) bool { // q.ep.receiveBufferParams() must be called without holding q.mu to // avoid lock order inversion. - bufSz := q.ep.receiveBufferSize() + bufSz := q.ep.ops.GetReceiveBufferSize() used := q.ep.receiveMemUsed() q.mu.Lock() // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue // is currently full). - allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen if allow { q.list.PushBack(s) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f26c7ca10..c9f2f3efc 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "fmt" "io" "sync/atomic" @@ -89,12 +88,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList udpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -144,6 +142,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // +stateify savable @@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } return e @@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.mu.Lock() e.sendTOS = uint8(v) e.mu.Unlock() - - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) - } - - if v < rs.Min { - v = rs.Min - } - if v > rs.Max { - v = rs.Max - } - - e.mu.Lock() - e.rcvBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.mu.Lock() v := int(e.ttl) @@ -1310,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -1444,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 21a6aa460..4aba68b21 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() + e.mu.Lock() defer e.mu.Unlock() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index 5d7859a9e..092daa0b8 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -63,7 +63,6 @@ func (e *endpoint) StateFields() []string { "uniqueID", "rcvReady", "rcvList", - "rcvBufSizeMax", "rcvBufSize", "rcvClosed", "state", @@ -82,38 +81,38 @@ func (e *endpoint) StateFields() []string { "effectiveNetProtos", "owner", "ops", + "frozen", } } // +checklocksignore func (e *endpoint) StateSave(stateSinkObject state.Sink) { e.beforeSave() - var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() - stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) stateSinkObject.Save(3, &e.uniqueID) stateSinkObject.Save(4, &e.rcvReady) stateSinkObject.Save(5, &e.rcvList) - stateSinkObject.Save(7, &e.rcvBufSize) - stateSinkObject.Save(8, &e.rcvClosed) - stateSinkObject.Save(9, &e.state) - stateSinkObject.Save(10, &e.dstPort) - stateSinkObject.Save(11, &e.ttl) - stateSinkObject.Save(12, &e.multicastTTL) - stateSinkObject.Save(13, &e.multicastAddr) - stateSinkObject.Save(14, &e.multicastNICID) - stateSinkObject.Save(15, &e.portFlags) - stateSinkObject.Save(16, &e.lastError) - stateSinkObject.Save(17, &e.boundBindToDevice) - stateSinkObject.Save(18, &e.boundPortFlags) - stateSinkObject.Save(19, &e.sendTOS) - stateSinkObject.Save(20, &e.shutdownFlags) - stateSinkObject.Save(21, &e.multicastMemberships) - stateSinkObject.Save(22, &e.effectiveNetProtos) - stateSinkObject.Save(23, &e.owner) - stateSinkObject.Save(24, &e.ops) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.state) + stateSinkObject.Save(9, &e.dstPort) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.multicastTTL) + stateSinkObject.Save(12, &e.multicastAddr) + stateSinkObject.Save(13, &e.multicastNICID) + stateSinkObject.Save(14, &e.portFlags) + stateSinkObject.Save(15, &e.lastError) + stateSinkObject.Save(16, &e.boundBindToDevice) + stateSinkObject.Save(17, &e.boundPortFlags) + stateSinkObject.Save(18, &e.sendTOS) + stateSinkObject.Save(19, &e.shutdownFlags) + stateSinkObject.Save(20, &e.multicastMemberships) + stateSinkObject.Save(21, &e.effectiveNetProtos) + stateSinkObject.Save(22, &e.owner) + stateSinkObject.Save(23, &e.ops) + stateSinkObject.Save(24, &e.frozen) } // +checklocksignore @@ -124,25 +123,25 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(3, &e.uniqueID) stateSourceObject.Load(4, &e.rcvReady) stateSourceObject.Load(5, &e.rcvList) - stateSourceObject.Load(7, &e.rcvBufSize) - stateSourceObject.Load(8, &e.rcvClosed) - stateSourceObject.Load(9, &e.state) - stateSourceObject.Load(10, &e.dstPort) - stateSourceObject.Load(11, &e.ttl) - stateSourceObject.Load(12, &e.multicastTTL) - stateSourceObject.Load(13, &e.multicastAddr) - stateSourceObject.Load(14, &e.multicastNICID) - stateSourceObject.Load(15, &e.portFlags) - stateSourceObject.Load(16, &e.lastError) - stateSourceObject.Load(17, &e.boundBindToDevice) - stateSourceObject.Load(18, &e.boundPortFlags) - stateSourceObject.Load(19, &e.sendTOS) - stateSourceObject.Load(20, &e.shutdownFlags) - stateSourceObject.Load(21, &e.multicastMemberships) - stateSourceObject.Load(22, &e.effectiveNetProtos) - stateSourceObject.Load(23, &e.owner) - stateSourceObject.Load(24, &e.ops) - stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.state) + stateSourceObject.Load(9, &e.dstPort) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.multicastTTL) + stateSourceObject.Load(12, &e.multicastAddr) + stateSourceObject.Load(13, &e.multicastNICID) + stateSourceObject.Load(14, &e.portFlags) + stateSourceObject.Load(15, &e.lastError) + stateSourceObject.Load(16, &e.boundBindToDevice) + stateSourceObject.Load(17, &e.boundPortFlags) + stateSourceObject.Load(18, &e.sendTOS) + stateSourceObject.Load(19, &e.shutdownFlags) + stateSourceObject.Load(20, &e.multicastMemberships) + stateSourceObject.Load(21, &e.effectiveNetProtos) + stateSourceObject.Load(22, &e.owner) + stateSourceObject.Load(23, &e.ops) + stateSourceObject.Load(24, &e.frozen) stateSourceObject.AfterLoad(e.afterLoad) } |