diff options
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 15 |
3 files changed, 31 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index a22684de9..f097ac057 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -71,12 +71,14 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID - state endpointState - bindNICID tcpip.NICID - bindAddr tcpip.Address - regNICID tcpip.NICID - route stack.Route `state:"manual"` + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + bindAddr tcpip.Address + regNICID tcpip.NICID + route stack.Route `state:"manual"` } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { @@ -93,7 +95,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite // associated with it. func (e *endpoint) Close() { e.mu.Lock() - + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, ProtocolNumber4, e.id) @@ -205,6 +207,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc e.mu.RLock() defer e.mu.RUnlock() + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, tcpip.ErrClosedForSend + } + // Prepare for write. for { retry, err := e.prepareForWrite(to) @@ -465,8 +472,9 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection // to its peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags if e.state != stateConnected { return tcpip.ErrNotConnected diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8bfb68f91..e1b71e423 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1024,7 +1024,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } default: - return tcpip.ErrInvalidEndpointState + return tcpip.ErrNotConnected } return nil diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 03fb76f92..b2d7f9779 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -80,6 +80,9 @@ type endpoint struct { dstPort uint16 v6only bool + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -124,6 +127,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport // associated with it. func (e *endpoint) Close() { e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: @@ -236,6 +240,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc e.mu.RLock() defer e.mu.RUnlock() + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, tcpip.ErrClosedForSend + } + // Prepare for write. for { retry, err := e.prepareForWrite(to) @@ -562,13 +571,15 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection // to its peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() if e.state != stateConnected { return tcpip.ErrNotConnected } + e.shutdownFlags |= flags + if flags&tcpip.ShutdownRead != 0 { e.rcvMu.Lock() wasClosed := e.rcvClosed |