summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/endpoint.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go161
1 files changed, 116 insertions, 45 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cc49c8272..ac927569a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -361,6 +362,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // pendingAccepted is a synchronization primitive used to track number
+ // of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -374,7 +381,11 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
- // The goroutine undrain notification channel.
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
undrain chan struct{} `state:"nosave"`
// probe if not nil is invoked on every received segment. It is passed
@@ -574,6 +585,34 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
+// closePendingAcceptableConnections closes all connections that have completed
+// handshake but not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+ done := make(chan struct{})
+ // Spin a goroutine up as ranging on e.acceptedChan will just block when
+ // there are no more connections in the channel. Using a non-blocking
+ // select does not work as it can potentially select the default case
+ // even when there are pending writes but that are not yet written to
+ // the channel.
+ go func() {
+ defer close(done)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ }()
+ // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
+ // endpoints which have completed handshake but are not yet written to
+ // the e.acceptedChan. We wait here till the goroutine above can drain
+ // all such connections from e.acceptedChan.
+ e.pendingAccepted.Wait()
+ close(e.acceptedChan)
+ <-done
+ e.acceptedChan = nil
+}
+
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
@@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
- close(e.acceptedChan)
- for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
- n.Close()
- }
- e.acceptedChan = nil
+ e.closePendingAcceptableConnectionsLocked()
}
e.workerCleanup = false
@@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
@@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
return v, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
-
- e.mu.RLock()
- defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, nil, e.hardError
+ return 0, e.hardError
default:
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
}
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
- }
-
- e.sndBufMu.Lock()
-
// Check if the connection has already been closed for sends.
if e.sndClosed {
- e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
- // Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrWouldBlock
+ e.mu.RUnlock()
+ return 0, nil, err
}
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ // Copy in memory without holding sndBufMu so that worker goroutine can
+ // make progress independent of this operation.
v, perr := p.Get(avail)
if perr != nil {
- e.sndBufMu.Unlock()
return 0, nil, perr
}
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a
+ // write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due to a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
// Add data to the send queue.
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
e.sndBufUsed += l
e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
if e.workMu.TryLock() {
// Do the work inline.
@@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), nil, nil
+ return int64(l), nil, nil
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
// Make a copy of vec so we can modify the slide headers.
vec = append([][]byte(nil), vec...)
- var num uintptr
-
+ var num int64
for s := e.rcvList.Front(); s != nil; s = s.Next() {
views := s.data.Views()
@@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
n := copy(vec[0], v)
v = v[n:]
vec[0] = vec[0][n:]
- num += uintptr(n)
+ num += int64(n)
}
}
}
@@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
}
@@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}