diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dispatcher.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 140 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 150 |
10 files changed, 220 insertions, 181 deletions
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 0b51563cd..1261ad414 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() { // Precondition: m.mu must be read locked. func (m *mockMulticastGroupProtocol) Enabled() bool { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") } @@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } @@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ce9cebdaa..ae0bb4ace 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // or we are adding a new temporary or permanent address. // // The address MUST be write locked at this point. - defer addrState.mu.Unlock() + defer addrState.mu.Unlock() // +checklocksforce if permanent { if addrState.mu.kind.IsPermanent() { diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 782e74b24..068dab7ce 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -363,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } } @@ -626,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo // Don't re-unlock if both tuples are in the same bucket. if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } return true diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index cb316d27a..f9a15efb2 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: @@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d807b13b7..aa413ad05 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, } ep := h.ep - if err := h.complete(); err != nil { + // N.B. the endpoint is generated above by startHandshake, and will be + // returned locked. This first call is forced. + if err := h.complete(); err != nil { // +checklocksforce ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.stats.FailedConnectionAttempts.Increment() l.cleanupFailedHandshake(h) @@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() { } // Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() @@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - if err := h.complete(); err != nil { + // Note that startHandshake returns a locked endpoint. The + // force call here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e39d1623d..93ed161f9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -511,6 +511,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). +// +checklocks:h.ep.mu func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper @@ -1283,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() { e.keepalive.Unlock() } -// protocolMainLoop is the main loop of the TCP protocol. It runs in its own -// goroutine and is responsible for sending segments and handling received -// segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { - e.mu.Lock() - var closeTimer tcpip.Timer - var closeWaker sleep.Waker - - epilogue := func() { - // e.mu is expected to be hold upon entering this section. - if e.snd != nil { - e.snd.resendTimer.cleanup() - e.snd.probeTimer.cleanup() - e.snd.reorderTimer.cleanup() - } +// protocolMainLoopDone is called at the end of protocolMainLoop. +// +checklocksrelease:e.mu +func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) { + if e.snd != nil { + e.snd.resendTimer.cleanup() + e.snd.probeTimer.cleanup() + e.snd.reorderTimer.cleanup() + } - if closeTimer != nil { - closeTimer.Stop() - } + if closeTimer != nil { + closeTimer.Stop() + } - e.completeWorkerLocked() + e.completeWorkerLocked() - if e.drainDone != nil { - close(e.drainDone) - } + if e.drainDone != nil { + close(e.drainDone) + } - e.mu.Unlock() + e.mu.Unlock() - e.drainClosingSegmentQueue() + e.drainClosingSegmentQueue() - // When the protocol loop exits we should wake up our waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) - } + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) +} +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { + var ( + closeTimer tcpip.Timer + closeWaker sleep.Waker + ) + + e.mu.Lock() if handshake { - if err := e.h.complete(); err != nil { + if err := e.h.complete(); err != nil { // +checklocksforce e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1327,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.hardError = err e.workerCleanup = true - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) return err } } @@ -1472,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) - e.mu.Unlock() + e.mu.Unlock() // +checklocksforce <-e.undrain e.mu.Lock() } @@ -1533,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if err != nil { e.resetConnectionLocked(err) } - // Lock released below. - epilogue() } loop: @@ -1558,6 +1559,7 @@ loop: // just want to terminate the loop and cleanup the // endpoint. cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil case StateTimeWait: fallthrough @@ -1566,6 +1568,7 @@ loop: default: if err := funcs[v].f(); err != nil { cleanupOnError(err) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } } @@ -1589,13 +1592,13 @@ loop: // Handle any StateError transition from StateTimeWait. if e.EndpointState() == StateError { cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } e.transitionToStateCloseLocked() - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue @@ -1665,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() // should be executed after releasing the endpoint registrations. This is // done in cases where a new SYN is received during TIME_WAIT that carries // a sequence number larger than one see on the connection. +// +checklocks:e.mu func (e *endpoint) doTimeWait() (twReuse func()) { // Trigger a 2 * MSL time wait state. During this period // we will drop all incoming segments. diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index dff7cb89c..7d110516b 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) { case !ep.segmentQueue.empty(): p.epQ.enqueue(ep) } - ep.mu.Unlock() + ep.mu.Unlock() // +checklocksforce } else { ep.newSegmentWaker.Assert() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4acddc959..1ed4ba419 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -664,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The assumption behind spinning here being that background packet processing // should not be holding the lock for long and spinning reduces latency as we // avoid an expensive sleep/wakeup of of the syscall goroutine). +// +checklocksacquire:e.mu func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned @@ -683,7 +684,7 @@ func (e *endpoint) LockUser() { continue } atomic.StoreUint32(&e.ownedByUser, 1) - return + return // +checklocksforce } } @@ -700,7 +701,7 @@ func (e *endpoint) LockUser() { // protocol goroutine altogether. // // Precondition: e.LockUser() must have been called before calling e.UnlockUser() -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) UnlockUser() { // Lock segment queue before checking so that we avoid a race where // segments can be queued between the time we check if queue is empty @@ -736,12 +737,13 @@ func (e *endpoint) UnlockUser() { } // StopWork halts packet processing. Only to be used in tests. +// +checklocksacquire:e.mu func (e *endpoint) StopWork() { e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) ResumeWork() { e.mu.Unlock() } @@ -1480,86 +1482,95 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { return avail, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. +// readFromPayloader reads a slice from the Payloader. +// +checklocks:e.mu +// +checklocks:e.sndQueueInfo.sndQueueMu +func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() - e.LockUser() - defer e.UnlockUser() + e.UnlockUser() + defer e.LockUser() + } - nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndQueueInfo.sndQueueMu.Lock() - defer e.sndQueueInfo.sndQueueMu.Unlock() + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + n, err := p.Read(v) + if err != nil && err != io.EOF { + return nil, &tcpip.ErrBadBuffer{} + } + return v[:n], nil +} + +// queueSegment reads data from the payloader and returns a segment to be sent. +// +checklocks:e.mu +func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) { + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + v, err := e.readFromPayloader(p, opts, avail) + if err != nil { + return nil, 0, err + } + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. avail, err := e.isEndpointWritableLocked() if err != nil { e.stats.WriteErrors.WriteClosed.Increment() return nil, 0, err } - v, err := func() ([]byte, tcpip.Error) { - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndQueueInfo.sndQueueMu.Unlock() - defer e.sndQueueInfo.sndQueueMu.Lock() - - e.UnlockUser() - defer e.LockUser() - } - - // Fetch data. - if l := p.Len(); l < avail { - avail = l - } - if avail == 0 { - return nil, nil - } - v := make([]byte, avail) - n, err := p.Read(v) - if err != nil && err != io.EOF { - return nil, &tcpip.ErrBadBuffer{} - } - return v[:n], nil - }() - if len(v) == 0 || err != nil { - return nil, 0, err + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } + } - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.stats.WriteErrors.WriteClosed.Increment() - return nil, 0, err - } + // Add data to the send queue. + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) + e.sndQueueInfo.SndBufUsed += len(v) + e.snd.writeList.PushBack(s) - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - } + return s, len(v), nil +} - // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) - e.sndQueueInfo.SndBufUsed += len(v) - e.snd.writeList.PushBack(s) +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.LockUser() + defer e.UnlockUser() - return s, len(v), nil - }() // Return if either we didn't queue anything or if an error occurred while // attempting to queue data. + nextSeg, n, err := e.queueSegment(p, opts) if n == 0 || err != nil { return 0, err } + e.sendData(nextSeg) return int64(n), nil } @@ -2504,6 +2515,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. +// +checklocksrelease:e.mu func (e *endpoint) startAcceptedLoop() { e.workerRunning = true e.mu.Unlock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 65c86823a..2e709ed78 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, return nil, err } - // Start the protocol goroutine. - ep.startAcceptedLoop() + // Start the protocol goroutine. Note that the endpoint is returned + // from performHandshake locked. + ep.startAcceptedLoop() // +checklocksforce return ep, nil } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index def9d7186..82a3f2287 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: @@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. @@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() - lockReleased := false - defer func() { - if lockReleased { - return - } - e.mu.RUnlock() - }() + defer e.mu.RUnlock() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} + return udpPacketInfo{}, &tcpip.ErrClosedForSend{} } // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWrite(opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } if !retry { @@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp route := e.route dstPort := e.dstPort - if to != nil { + if opts.To != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicID := to.NIC + nicID := opts.To.NIC if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} + return udpPacketInfo{}, &tcpip.ErrNoRoute{} } nicID = e.BindNICID } - if to.Port == 0 { + if opts.To.Port == 0 { // Port 0 is an invalid port to send to. - return 0, &tcpip.ErrInvalidEndpointState{} + return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*to) + dst, netProto, err := e.checkV4MappedLocked(*opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { - return 0, err + return udpPacketInfo{}, err } defer r.Release() @@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, &tcpip.ErrBroadcastDisabled{} + return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} + return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp v, ) } - return 0, &tcpip.ErrMessageTooLong{} + return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } ttl := e.ttl useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { ttl = e.multicastTTL // Multicast allows a 0 TTL. useDefaultTTL = false } - localPort := e.ID.LocalPort - sendTOS := e.sendTOS - owner := e.owner - noChecksum := e.SocketOptions().GetNoChecksum() - lockReleased = true - e.mu.RUnlock() + return udpPacketInfo{ + route: route, + data: buffer.View(v), + localPort: e.ID.LocalPort, + remotePort: dstPort, + ttl: ttl, + useDefaultTTL: useDefaultTTL, + tos: e.sendTOS, + owner: e.owner, + noChecksum: e.SocketOptions().GetNoChecksum(), + }, nil +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + if err := e.LastError(); err != nil { + return 0, err + } + + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, &tcpip.ErrInvalidOptionValue{} + } // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a @@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { + u, err := e.buildUDPPacketInfo(p, opts) + if err != nil { return 0, err } - return int64(len(v)), nil + n, err := u.send() + if err != nil { + return 0, err + } + return int64(n), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { return nil } -// sendUDP sends a UDP segment via the provided network endpoint and under the -// provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { +// udpPacketInfo contains all information required to send a UDP packet. +// +// This should be used as a value-only type, which exists in order to simplify +// return value syntax. It should not be exported or extended. +type udpPacketInfo struct { + route *stack.Route + data buffer.View + localPort uint16 + remotePort uint16 + ttl uint8 + useDefaultTTL bool + tos uint8 + owner tcpip.PacketOwner + noChecksum bool +} + +// send sends the given packet. +func (u *udpPacketInfo) send() (int, tcpip.Error) { + vv := u.data.ToVectorisedView() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: data, + ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), + Data: vv, }) - pkt.Owner = owner + pkt.Owner = u.owner // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) @@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ - SrcPort: localPort, - DstPort: remotePort, + SrcPort: u.localPort, + DstPort: u.remotePort, Length: length, }) @@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range data.Views() { + if u.route.RequiresTXTransportChecksum() && + (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { + xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) + for _, v := range vv.Views() { xsum = header.Checksum(v, xsum) } udp.SetChecksum(^udp.CalculateChecksum(xsum)) } - if useDefaultTTL { - ttl = r.DefaultTTL() + if u.useDefaultTTL { + u.ttl = u.route.DefaultTTL() } - if err := r.WritePacket(stack.NetworkHeaderParams{ + if err := u.route.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, - TTL: ttl, - TOS: tos, + TTL: u.ttl, + TOS: u.tos, }, pkt); err != nil { - r.Stats().UDP.PacketSendErrors.Increment() - return err + u.route.Stats().UDP.PacketSendErrors.Increment() + return 0, err } // Track count of packets sent. - r.Stats().UDP.PacketsSent.Increment() - return nil + u.route.Stats().UDP.PacketsSent.Increment() + return len(u.data), nil } // checkV4MappedLocked determines the effective network protocol and converts |