diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 47 |
1 files changed, 35 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4f4f4c65e..a2161e49d 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -315,11 +315,6 @@ func (*Stats) IsEndpointStats() {} // +stateify savable type EndpointInfo struct { stack.TransportEndpointInfo - - // HardError is meaningful only when state is stateError. It stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. HardError is protected by endpoint mu. - HardError *tcpip.Error `state:".(string)"` } // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -386,6 +381,11 @@ type endpoint struct { waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 + // hardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by endpoint mu. + hardError *tcpip.Error `state:".(string)"` + // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. lastErrorMu sync.Mutex `state:"nosave"` @@ -1283,7 +1283,15 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) LastError() *tcpip.Error { +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) hardErrorLocked() *tcpip.Error { + err := e.hardError + e.hardError = nil + return err +} + +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) lastErrorLocked() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1291,6 +1299,15 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +func (e *endpoint) LastError() *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + if err := e.hardErrorLocked(); err != nil { + return err + } + return e.lastErrorLocked() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1312,9 +1329,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.HardError if s == StateError { - return buffer.View{}, tcpip.ControlMessages{}, he + return buffer.View{}, tcpip.ControlMessages{}, e.hardErrorLocked() } e.stats.ReadErrors.NotConnected.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected @@ -1370,9 +1386,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: - return 0, e.HardError + if err := e.hardErrorLocked(); err != nil { + return 0, err + } + return 0, tcpip.ErrClosedForSend case !s.connecting() && !s.connected(): return 0, tcpip.ErrClosedForSend case s.connecting(): @@ -1486,7 +1506,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.HardError + return 0, tcpip.ControlMessages{}, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -2243,7 +2263,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnecting case StateError: - return e.HardError + if err := e.hardErrorLocked(); err != nil { + return err + } + return tcpip.ErrConnectionAborted default: return tcpip.ErrInvalidEndpointState @@ -2417,7 +2440,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err // Call cleanupLocked to free up any reservations. e.cleanupLocked() |