summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/ping/endpoint.go26
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go15
3 files changed, 31 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index a22684de9..f097ac057 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -71,12 +71,14 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- bindAddr tcpip.Address
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
@@ -93,7 +95,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
-
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, ProtocolNumber4, e.id)
@@ -205,6 +207,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
e.mu.RLock()
defer e.mu.RUnlock()
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, tcpip.ErrClosedForSend
+ }
+
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
@@ -465,8 +472,9 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
if e.state != stateConnected {
return tcpip.ErrNotConnected
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 8bfb68f91..e1b71e423 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1024,7 +1024,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
default:
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrNotConnected
}
return nil
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 03fb76f92..b2d7f9779 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -80,6 +80,9 @@ type endpoint struct {
dstPort uint16
v6only bool
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -124,6 +127,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
@@ -236,6 +240,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
e.mu.RLock()
defer e.mu.RUnlock()
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, tcpip.ErrClosedForSend
+ }
+
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
@@ -562,13 +571,15 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
if e.state != stateConnected {
return tcpip.ErrNotConnected
}
+ e.shutdownFlags |= flags
+
if flags&tcpip.ShutdownRead != 0 {
e.rcvMu.Lock()
wasClosed := e.rcvClosed