diff options
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 57976d4e3..835dcc54e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -429,7 +429,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c to := opts.To e.mu.RLock() - defer e.mu.RUnlock() + lockReleased := false + defer func() { + if lockReleased { + return + } + e.mu.RUnlock() + }() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { @@ -475,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.state != StateConnected { err = tcpip.ErrInvalidEndpointState } - return + return ch, err } } else { // Reject destination address if it goes through a different @@ -541,7 +547,24 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + localPort := e.ID.LocalPort + sendTOS := e.sendTOS + owner := e.owner + noChecksum := e.noChecksum + lockReleased = true + e.mu.RUnlock() + + // Do not hold lock when sending as loopback is synchronous and if the UDP + // datagram ends up generating an ICMP response then it can result in a + // deadlock where the ICMP response handling ends up acquiring this endpoint's + // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a + // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ + // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the + // lock can be eventually acquired. + // + // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read + // locking is prohibited. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil |