diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 161 |
1 files changed, 116 insertions, 45 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc49c8272..ac927569a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -361,6 +362,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // pendingAccepted is a synchronization primitive used to track number + // of connections that are queued up to be delivered to the accepted + // channel. We use this to ensure that all goroutines blocked on writing + // to the acceptedChan below terminate before we close acceptedChan. + pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -374,7 +381,11 @@ type endpoint struct { // The goroutine drain completion notification channel. drainDone chan struct{} `state:"nosave"` - // The goroutine undrain notification channel. + // The goroutine undrain notification channel. This is currently used as + // a way to block the worker goroutines. Today nothing closes/writes + // this channel and this causes any goroutines waiting on this to just + // block. This is used during save/restore to prevent worker goroutines + // from mutating state as it's being saved. undrain chan struct{} `state:"nosave"` // probe if not nil is invoked on every received segment. It is passed @@ -574,6 +585,34 @@ func (e *endpoint) Close() { e.mu.Unlock() } +// closePendingAcceptableConnections closes all connections that have completed +// handshake but not yet been delivered to the application. +func (e *endpoint) closePendingAcceptableConnectionsLocked() { + done := make(chan struct{}) + // Spin a goroutine up as ranging on e.acceptedChan will just block when + // there are no more connections in the channel. Using a non-blocking + // select does not work as it can potentially select the default case + // even when there are pending writes but that are not yet written to + // the channel. + go func() { + defer close(done) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + }() + // pendingAccepted(see endpoint.deliverAccepted) tracks the number of + // endpoints which have completed handshake but are not yet written to + // the e.acceptedChan. We wait here till the goroutine above can drain + // all such connections from e.acceptedChan. + e.pendingAccepted.Wait() + close(e.acceptedChan) + <-done + e.acceptedChan = nil +} + // cleanupLocked frees all resources associated with the endpoint. It is called // after Close() is called and the worker goroutine (if any) is done with its // work. @@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { - close(e.acceptedChan) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - e.acceptedChan = nil + e.closePendingAcceptableConnectionsLocked() } e.workerCleanup = false @@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { return v, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. - - e.mu.RLock() - defer e.mu.RUnlock() - +// isEndpointWritableLocked checks if a given endpoint is writable +// and also returns the number of bytes that can be written at this +// moment. If the endpoint is not writable then it returns an error +// indicating the reason why it's not writable. +// Caller must hold e.mu and e.sndBufMu +func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. if !e.state.connected() { switch e.state { case StateError: - return 0, nil, e.hardError + return 0, e.hardError default: - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } } - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil - } - - e.sndBufMu.Lock() - // Check if the connection has already been closed for sends. if e.sndClosed { - e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } - // Check against the limit. avail := e.sndBufSize - e.sndBufUsed if avail <= 0 { + return 0, tcpip.ErrWouldBlock + } + return avail, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + e.sndBufMu.Lock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrWouldBlock + e.mu.RUnlock() + return 0, nil, err } + e.sndBufMu.Unlock() + e.mu.RUnlock() + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil, nil + } + + // Copy in memory without holding sndBufMu so that worker goroutine can + // make progress independent of this operation. v, perr := p.Get(avail) if perr != nil { - e.sndBufMu.Unlock() return 0, nil, perr } - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a + // write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due to a + // simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } // Add data to the send queue. + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) e.sndBufUsed += l e.sndBufInQueue += seqnum.Size(l) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() if e.workMu.TryLock() { // Do the work inline. @@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return uintptr(l), nil, nil + return int64(l), nil, nil } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // Make a copy of vec so we can modify the slide headers. vec = append([][]byte(nil), vec...) - var num uintptr - + var num int64 for s := e.rcvList.Front(); s != nil; s = s.Next() { views := s.data.Views() @@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er n := copy(vec[0], v) v = v[n:] vec[0] = vec[0][n:] - num += uintptr(n) + num += int64(n) } } } @@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } } @@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" && addr.Port == 0 { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - return e.connect(addr, true, true) } |