author     Adin Scannell <ascannell@google.com>    2021-07-01 15:05:28 -0700
committer  gVisor bot <gvisor-bot@google.com>      2021-07-01 15:07:56 -0700
commit     16b751b6c610ec2c5a913cb8a818e9239ee7da71 (patch)
tree       5596ea010c6afbbe79d1196197cd4bfc5d517e79 /pkg/tcpip
parent     570ca571805d6939c4c24b6a88660eefaf558ae7 (diff)
Mix checklocks and atomic analyzers.
This change makes the checklocks analyzer considerably more powerful, adding the following (a brief sketch of the annotation forms appears after this list):
* The ability to traverse complex structures, e.g. to have multiple nested
fields as part of the annotation.
* The ability to resolve simple anonymous functions and closures, and perform
lock analysis across these invocations. This does not apply to closures that
are passed elsewhere, since it is not possible to know the context in which
they might be invoked.
* The ability to annotate return values in addition to receivers and other
parameters, with the same complex structures noted above.
* Ignoring locking semantics for "fresh" objects, i.e. objects that are
allocated in the local frame (typically a new-style function).
* Sanity checking of locking state across block transitions and returns, to
ensure that no unexpected locks are held.
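For readers unfamiliar with the annotation syntax, here is a minimal sketch of what such annotations can look like. The `+checklocks:` and `+checklocksacquire:` forms appear in the diff below; the nested-field and named-return-value usages are illustrative assumptions about the capabilities listed above and should be checked against the analyzer's documentation.

    package example

    import "sync"

    type inner struct {
        mu sync.Mutex

        // +checklocks:mu
        count int
    }

    type outer struct {
        in inner
    }

    // The annotation names a nested field of the receiver: o.in.mu must be
    // held when calling this method.
    //
    // +checklocks:o.in.mu
    func (o *outer) incLocked() {
        o.in.count++
    }

    // Hypothetical return-value annotation: the returned *inner is locked
    // when newLockedInner returns, and the caller is responsible for
    // unlocking it.
    //
    // +checklocksacquire:ret.mu
    func newLockedInner() (ret *inner) {
        ret = &inner{}
        ret.mu.Lock()
        return ret
    }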
Note that initially, most of these findings are excluded by a comprehensive
nogo.yaml. The findings that are included are fundamental lock violations.
The changes here should be relatively low risk: minor refactorings that either
add the necessary annotations or simplify the code structure (in general,
removing closures in favor of methods) so that the analyzer can easily track
the lock state; the refactoring pattern is sketched below.
This change additionally includes two changes to nogo itself:
* Sanity checking of all types to ensure that the binary and AST-derived
types have a consistent objectpath, to prevent the bug above from occurring
silently (and causing much confusion). This also requires a trick in
order to ensure that serialized facts are consumable downstream. This can
be removed once https://go-review.googlesource.com/c/tools/+/331789 is
merged. A rough sketch of this consistency check appears below.
* A minor refactoring to isolate the objdump settings in their own package.
This was originally used to implement the sanity check above, but that
information is now being passed another way. The minor refactor is preserved,
however, since it cleans up the code slightly and carries minimal risk.
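The consistency check could plausibly be built on golang.org/x/tools/go/types/objectpath; the helper below is an assumption about the general approach, not the actual nogo implementation.

    package example

    import (
        "fmt"
        "go/types"

        "golang.org/x/tools/go/types/objectpath"
    )

    // checkConsistentPath (hypothetical helper) encodes an object from the
    // AST-derived package and re-resolves the resulting path in the
    // export-data (binary) derived package, reporting an error if the path
    // does not round-trip.
    func checkConsistentPath(astObj types.Object, binPkg *types.Package) error {
        path, err := objectpath.For(astObj)
        if err != nil {
            return fmt.Errorf("no objectpath for %v: %w", astObj, err)
        }
        if _, err := objectpath.Object(binPkg, path); err != nil {
            return fmt.Errorf("objectpath %q does not resolve in %s: %w", path, binPkg.Path(), err)
        }
        return nil
    }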
PiperOrigin-RevId: 382613300
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--  pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go |  10
-rw-r--r--  pkg/tcpip/stack/addressable_endpoint_state.go                    |   2
-rw-r--r--  pkg/tcpip/stack/conntrack.go                                     |   4
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go                             |   5
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go                                |   9
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go                               |  74
-rw-r--r--  pkg/tcpip/transport/tcp/dispatcher.go                            |   2
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go                              | 140
-rw-r--r--  pkg/tcpip/transport/tcp/forwarder.go                             |   5
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go                              | 150
10 files changed, 220 insertions, 181 deletions
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 0b51563cd..1261ad414 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() { // Precondition: m.mu must be read locked. func (m *mockMulticastGroupProtocol) Enabled() bool { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") } @@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } @@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ce9cebdaa..ae0bb4ace 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // or we are adding a new temporary or permanent address. // // The address MUST be write locked at this point. - defer addrState.mu.Unlock() + defer addrState.mu.Unlock() // +checklocksforce if permanent { if addrState.mu.kind.IsPermanent() { diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 782e74b24..068dab7ce 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -363,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } } @@ -626,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo // Don't re-unlock if both tuples are in the same bucket. 
if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } return true diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index cb316d27a..f9a15efb2 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: @@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d807b13b7..aa413ad05 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, } ep := h.ep - if err := h.complete(); err != nil { + // N.B. the endpoint is generated above by startHandshake, and will be + // returned locked. This first call is forced. + if err := h.complete(); err != nil { // +checklocksforce ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.stats.FailedConnectionAttempts.Increment() l.cleanupFailedHandshake(h) @@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() { } // Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() @@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - if err := h.complete(); err != nil { + // Note that startHandshake returns a locked endpoint. The + // force call here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e39d1623d..93ed161f9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -511,6 +511,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). +// +checklocks:h.ep.mu func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper @@ -1283,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() { e.keepalive.Unlock() } -// protocolMainLoop is the main loop of the TCP protocol. It runs in its own -// goroutine and is responsible for sending segments and handling received -// segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { - e.mu.Lock() - var closeTimer tcpip.Timer - var closeWaker sleep.Waker - - epilogue := func() { - // e.mu is expected to be hold upon entering this section. - if e.snd != nil { - e.snd.resendTimer.cleanup() - e.snd.probeTimer.cleanup() - e.snd.reorderTimer.cleanup() - } +// protocolMainLoopDone is called at the end of protocolMainLoop. 
+// +checklocksrelease:e.mu +func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) { + if e.snd != nil { + e.snd.resendTimer.cleanup() + e.snd.probeTimer.cleanup() + e.snd.reorderTimer.cleanup() + } - if closeTimer != nil { - closeTimer.Stop() - } + if closeTimer != nil { + closeTimer.Stop() + } - e.completeWorkerLocked() + e.completeWorkerLocked() - if e.drainDone != nil { - close(e.drainDone) - } + if e.drainDone != nil { + close(e.drainDone) + } - e.mu.Unlock() + e.mu.Unlock() - e.drainClosingSegmentQueue() + e.drainClosingSegmentQueue() - // When the protocol loop exits we should wake up our waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) - } + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) +} +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { + var ( + closeTimer tcpip.Timer + closeWaker sleep.Waker + ) + + e.mu.Lock() if handshake { - if err := e.h.complete(); err != nil { + if err := e.h.complete(); err != nil { // +checklocksforce e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1327,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.hardError = err e.workerCleanup = true - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) return err } } @@ -1472,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) - e.mu.Unlock() + e.mu.Unlock() // +checklocksforce <-e.undrain e.mu.Lock() } @@ -1533,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if err != nil { e.resetConnectionLocked(err) } - // Lock released below. - epilogue() } loop: @@ -1558,6 +1559,7 @@ loop: // just want to terminate the loop and cleanup the // endpoint. cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil case StateTimeWait: fallthrough @@ -1566,6 +1568,7 @@ loop: default: if err := funcs[v].f(); err != nil { cleanupOnError(err) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } } @@ -1589,13 +1592,13 @@ loop: // Handle any StateError transition from StateTimeWait. if e.EndpointState() == StateError { cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } e.transitionToStateCloseLocked() - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue @@ -1665,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() // should be executed after releasing the endpoint registrations. This is // done in cases where a new SYN is received during TIME_WAIT that carries // a sequence number larger than one see on the connection. +// +checklocks:e.mu func (e *endpoint) doTimeWait() (twReuse func()) { // Trigger a 2 * MSL time wait state. During this period // we will drop all incoming segments. 
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index dff7cb89c..7d110516b 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) { case !ep.segmentQueue.empty(): p.epQ.enqueue(ep) } - ep.mu.Unlock() + ep.mu.Unlock() // +checklocksforce } else { ep.newSegmentWaker.Assert() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4acddc959..1ed4ba419 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -664,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The assumption behind spinning here being that background packet processing // should not be holding the lock for long and spinning reduces latency as we // avoid an expensive sleep/wakeup of of the syscall goroutine). +// +checklocksacquire:e.mu func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned @@ -683,7 +684,7 @@ func (e *endpoint) LockUser() { continue } atomic.StoreUint32(&e.ownedByUser, 1) - return + return // +checklocksforce } } @@ -700,7 +701,7 @@ func (e *endpoint) LockUser() { // protocol goroutine altogether. // // Precondition: e.LockUser() must have been called before calling e.UnlockUser() -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) UnlockUser() { // Lock segment queue before checking so that we avoid a race where // segments can be queued between the time we check if queue is empty @@ -736,12 +737,13 @@ func (e *endpoint) UnlockUser() { } // StopWork halts packet processing. Only to be used in tests. +// +checklocksacquire:e.mu func (e *endpoint) StopWork() { e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) ResumeWork() { e.mu.Unlock() } @@ -1480,86 +1482,95 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { return avail, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. +// readFromPayloader reads a slice from the Payloader. +// +checklocks:e.mu +// +checklocks:e.sndQueueInfo.sndQueueMu +func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() - e.LockUser() - defer e.UnlockUser() + e.UnlockUser() + defer e.LockUser() + } - nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndQueueInfo.sndQueueMu.Lock() - defer e.sndQueueInfo.sndQueueMu.Unlock() + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + n, err := p.Read(v) + if err != nil && err != io.EOF { + return nil, &tcpip.ErrBadBuffer{} + } + return v[:n], nil +} + +// queueSegment reads data from the payloader and returns a segment to be sent. 
+// +checklocks:e.mu +func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) { + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + v, err := e.readFromPayloader(p, opts, avail) + if err != nil { + return nil, 0, err + } + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. avail, err := e.isEndpointWritableLocked() if err != nil { e.stats.WriteErrors.WriteClosed.Increment() return nil, 0, err } - v, err := func() ([]byte, tcpip.Error) { - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndQueueInfo.sndQueueMu.Unlock() - defer e.sndQueueInfo.sndQueueMu.Lock() - - e.UnlockUser() - defer e.LockUser() - } - - // Fetch data. - if l := p.Len(); l < avail { - avail = l - } - if avail == 0 { - return nil, nil - } - v := make([]byte, avail) - n, err := p.Read(v) - if err != nil && err != io.EOF { - return nil, &tcpip.ErrBadBuffer{} - } - return v[:n], nil - }() - if len(v) == 0 || err != nil { - return nil, 0, err + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } + } - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.stats.WriteErrors.WriteClosed.Increment() - return nil, 0, err - } + // Add data to the send queue. + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) + e.sndQueueInfo.SndBufUsed += len(v) + e.snd.writeList.PushBack(s) - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - } + return s, len(v), nil +} - // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) - e.sndQueueInfo.SndBufUsed += len(v) - e.snd.writeList.PushBack(s) +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.LockUser() + defer e.UnlockUser() - return s, len(v), nil - }() // Return if either we didn't queue anything or if an error occurred while // attempting to queue data. + nextSeg, n, err := e.queueSegment(p, opts) if n == 0 || err != nil { return 0, err } + e.sendData(nextSeg) return int64(n), nil } @@ -2504,6 +2515,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. 
+// +checklocksrelease:e.mu func (e *endpoint) startAcceptedLoop() { e.workerRunning = true e.mu.Unlock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 65c86823a..2e709ed78 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, return nil, err } - // Start the protocol goroutine. - ep.startAcceptedLoop() + // Start the protocol goroutine. Note that the endpoint is returned + // from performHandshake locked. + ep.startAcceptedLoop() // +checklocksforce return ep, nil } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index def9d7186..82a3f2287 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: @@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. @@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() - lockReleased := false - defer func() { - if lockReleased { - return - } - e.mu.RUnlock() - }() + defer e.mu.RUnlock() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} + return udpPacketInfo{}, &tcpip.ErrClosedForSend{} } // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWrite(opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } if !retry { @@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp route := e.route dstPort := e.dstPort - if to != nil { + if opts.To != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicID := to.NIC + nicID := opts.To.NIC if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} + return udpPacketInfo{}, &tcpip.ErrNoRoute{} } nicID = e.BindNICID } - if to.Port == 0 { + if opts.To.Port == 0 { // Port 0 is an invalid port to send to. 
- return 0, &tcpip.ErrInvalidEndpointState{} + return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*to) + dst, netProto, err := e.checkV4MappedLocked(*opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { - return 0, err + return udpPacketInfo{}, err } defer r.Release() @@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, &tcpip.ErrBroadcastDisabled{} + return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} + return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp v, ) } - return 0, &tcpip.ErrMessageTooLong{} + return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } ttl := e.ttl useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { ttl = e.multicastTTL // Multicast allows a 0 TTL. useDefaultTTL = false } - localPort := e.ID.LocalPort - sendTOS := e.sendTOS - owner := e.owner - noChecksum := e.SocketOptions().GetNoChecksum() - lockReleased = true - e.mu.RUnlock() + return udpPacketInfo{ + route: route, + data: buffer.View(v), + localPort: e.ID.LocalPort, + remotePort: dstPort, + ttl: ttl, + useDefaultTTL: useDefaultTTL, + tos: e.sendTOS, + owner: e.owner, + noChecksum: e.SocketOptions().GetNoChecksum(), + }, nil +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + if err := e.LastError(); err != nil { + return 0, err + } + + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, &tcpip.ErrInvalidOptionValue{} + } // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a @@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { + u, err := e.buildUDPPacketInfo(p, opts) + if err != nil { return 0, err } - return int64(len(v)), nil + n, err := u.send() + if err != nil { + return 0, err + } + return int64(n), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { return nil } -// sendUDP sends a UDP segment via the provided network endpoint and under the -// provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { +// udpPacketInfo contains all information required to send a UDP packet. +// +// This should be used as a value-only type, which exists in order to simplify +// return value syntax. It should not be exported or extended. 
+type udpPacketInfo struct { + route *stack.Route + data buffer.View + localPort uint16 + remotePort uint16 + ttl uint8 + useDefaultTTL bool + tos uint8 + owner tcpip.PacketOwner + noChecksum bool +} + +// send sends the given packet. +func (u *udpPacketInfo) send() (int, tcpip.Error) { + vv := u.data.ToVectorisedView() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: data, + ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), + Data: vv, }) - pkt.Owner = owner + pkt.Owner = u.owner // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) @@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ - SrcPort: localPort, - DstPort: remotePort, + SrcPort: u.localPort, + DstPort: u.remotePort, Length: length, }) @@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range data.Views() { + if u.route.RequiresTXTransportChecksum() && + (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { + xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) + for _, v := range vv.Views() { xsum = header.Checksum(v, xsum) } udp.SetChecksum(^udp.CalculateChecksum(xsum)) } - if useDefaultTTL { - ttl = r.DefaultTTL() + if u.useDefaultTTL { + u.ttl = u.route.DefaultTTL() } - if err := r.WritePacket(stack.NetworkHeaderParams{ + if err := u.route.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, - TTL: ttl, - TOS: tos, + TTL: u.ttl, + TOS: u.tos, }, pkt); err != nil { - r.Stats().UDP.PacketSendErrors.Increment() - return err + u.route.Stats().UDP.PacketSendErrors.Increment() + return 0, err } // Track count of packets sent. - r.Stats().UDP.PacketsSent.Increment() - return nil + u.route.Stats().UDP.PacketsSent.Increment() + return len(u.data), nil } // checkV4MappedLocked determines the effective network protocol and converts |