diff options
Diffstat (limited to 'pkg/tcpip/transport/raw')
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint_state.go | 33 |
2 files changed, 40 insertions, 69 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..10453a42a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. if e.connected { |