diff options
author | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
---|---|---|
committer | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
commit | ac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (patch) | |
tree | 0cbc5018e8807421d701d190dc20525726c7ca76 /pkg/tcpip/transport/tcp/endpoint.go | |
parent | 352ae1022ce19de28fc72e034cc469872ad79d06 (diff) | |
parent | 6d0c5803d557d453f15ac6f683697eeb46dab680 (diff) |
Merge branch 'master' into ip-forwarding
- Merges aleksej-paschenko's with HEAD
- Adds vfs2 support for ip_forward
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 1828 |
1 files changed, 1205 insertions, 623 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6ca0d73a9..1ccedebcc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -18,21 +18,21 @@ import ( "encoding/binary" "fmt" "math" + "runtime" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tmutex" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { @@ -119,8 +154,17 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -273,20 +317,59 @@ func (*EndpointInfo) IsEndpointInfo() {} // synchronized. The protocol implementation, however, runs in a single // goroutine. // +// Each endpoint has a few mutexes: +// +// e.mu -> Primary mutex for an endpoint must be held for all operations except +// in e.Readiness where acquiring it will result in a deadlock in epoll +// implementation. +// +// The following three mutexes can be acquired independent of e.mu but if +// acquired with e.mu then e.mu must be acquired first. +// +// e.acceptMu -> protects acceptedChan. +// e.rcvListMu -> Protects the rcvList and associated fields. +// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.lastErrorMu -> Protects the lastError field. +// +// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different +// based on the context in which the lock is acquired. In the syscall context +// e.LockUser/e.UnlockUser should be used and when doing background processing +// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below +// in brief. +// +// The reason for this locking behaviour is to avoid wakeups to handle packets. +// In cases where the endpoint is already locked the background processor can +// queue the packet up and go its merry way and the lock owner will eventually +// process the backlog when releasing the lock. Similarly when acquiring the +// lock from say a syscall goroutine we can implement a bit of spinning if we +// know that the lock is not held by another syscall goroutine. Background +// processors should never hold the lock for long and we can avoid an expensive +// sleep/wakeup by spinning for a shortwhile. +// +// For more details please see the detailed documentation on +// e.LockUser/e.UnlockUser methods. +// // +stateify savable type endpoint struct { EndpointInfo - // workMu is used to arbitrate which goroutine may perform protocol - // work. Only the main protocol goroutine is expected to call Lock() on - // it, but other goroutines (e.g., send) may call TryLock() to eagerly - // perform work without having to wait for the main one to wake up. - workMu tmutex.Mutex `state:"nosave"` + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -307,21 +390,24 @@ type endpoint struct { rcvBufSize int rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams - // zeroWindow indicates that the window was closed due to receive buffer - // space being filled up. This is set by the worker goroutine before - // moving a segment to the rcvList. This setting is cleared by the - // endpoint when a Read() call reads enough data for the new window to - // be non-zero. - zeroWindow bool - // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` + // mu protects all endpoint fields unless documented otherwise. mu must + // be acquired before interacting with the endpoint fields. + mu sync.Mutex `state:"nosave"` + ownedByUser uint32 + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -330,19 +416,28 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + boundDest tcpip.FullAddress + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -356,6 +451,9 @@ type endpoint struct { // updated if required when a new segment is received by this endpoint. recentTS uint32 + // recentTSTime is the unix time when we updated recentTS last. + recentTSTime time.Time `state:".(unixTime)"` + // tsOffset is a randomized offset added to the value of the // TSVal field in the timestamp option. tsOffset uint32 @@ -370,9 +468,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -392,7 +487,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -411,7 +505,18 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 + + // maxSynRetries is the maximum number of SYN retransmits that TCP should + // send before aborting the attempt to connect. It cannot exceed 255. + // + // NOTE: This is currently a no-op and does not change the SYN + // retransmissions. + maxSynRetries uint8 + + // windowClamp is used to bound the size of the advertised window to + // this value. + windowClamp uint32 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -458,12 +563,42 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing // to the acceptedChan below terminate before we close acceptedChan. pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptMu protects acceptedChan. + acceptMu sync.Mutex `state:"nosave"` + + // acceptCond is a condition variable that can be used to block on when + // acceptedChan is full and an endpoint is ready to be delivered. + // + // This condition variable is required because just blocking on sending + // to acceptedChan does not work in cases where endpoint.Listen is + // called twice with different backlog values. In such cases the channel + // is closed and a new one created. Any pending goroutines blocking on + // the write to the channel will panic. + // + // We use this condition variable to block/unblock goroutines which + // tried to deliver an endpoint but couldn't because accept backlog was + // full ( See: endpoint.deliverAccepted ). + acceptCond *sync.Cond `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -502,16 +637,175 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool + + // txHash is the transport layer hash to be set on outbound packets + // emitted by this endpoint. + txHash uint32 + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + // TODO(b/143359391): Respect TCP Min and Max size. + maxMSS := uint16(r.MTU() - header.TCPMinimumSize) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS +} + +// LockUser tries to lock e.mu and if it fails it will check if the lock is held +// by another syscall goroutine. If yes, then it will goto sleep waiting for the +// lock to be released, if not then it will spin till it acquires the lock or +// another syscall goroutine acquires it in which case it will goto sleep as +// described above. +// +// The assumption behind spinning here being that background packet processing +// should not be holding the lock for long and spinning reduces latency as we +// avoid an expensive sleep/wakeup of of the syscall goroutine). +func (e *endpoint) LockUser() { + for { + // Try first if the sock is locked then check if it's owned + // by another user goroutine if not then we spin, otherwise + // we just goto sleep on the Lock() and wait. + if !e.mu.TryLock() { + // If socket is owned by the user then just goto sleep + // as the lock could be held for a reasonably long time. + if atomic.LoadUint32(&e.ownedByUser) == 1 { + e.mu.Lock() + atomic.StoreUint32(&e.ownedByUser, 1) + return + } + // Spin but yield the processor since the lower half + // should yield the lock soon. + runtime.Gosched() + continue + } + atomic.StoreUint32(&e.ownedByUser, 1) + return + } +} + +// UnlockUser will check if there are any segments already queued for processing +// and process any such segments before unlocking e.mu. This is required because +// we when packets arrive and endpoint lock is already held then such packets +// are queued up to be processed. If the lock is held by the endpoint goroutine +// then it will process these packets but if the lock is instead held by the +// syscall goroutine then we can have the syscall goroutine process the backlog +// before unlocking. +// +// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the +// endpoint. It's also required eventually when we get rid of the endpoint +// protocol goroutine altogether. +// +// Precondition: e.LockUser() must have been called before calling e.UnlockUser() +func (e *endpoint) UnlockUser() { + // Lock segment queue before checking so that we avoid a race where + // segments can be queued between the time we check if queue is empty + // and actually unlock the endpoint mutex. + for { + e.segmentQueue.mu.Lock() + if e.segmentQueue.emptyLocked() { + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + e.segmentQueue.mu.Unlock() + return + } + e.segmentQueue.mu.Unlock() + + switch e.EndpointState() { + case StateEstablished: + if err := e.handleSegments(true /* fastPath */); err != nil { + e.notifyProtocolGoroutine(notifyTickleWorker) + } + default: + // Since we are waking the endpoint goroutine here just unlock + // and let it process the queued segments. + e.newSegmentWaker.Assert() + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + return + } + } } // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { - e.workMu.Lock() + e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. func (e *endpoint) ResumeWork() { - e.workMu.Unlock() + e.mu.Unlock() +} + +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.stack.Stats().TCP.CurrentConnected.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp sets the recentTS field to the provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + e.recentTS = recentTS + e.recentTSTime = time.Now() +} + +// recentTimestamp returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return e.recentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -543,13 +837,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), + windowClamp: DefaultReceiveBufferSize, + maxSynRetries: DefaultSynRetries, } var ss SendBufferSizeOption @@ -572,14 +869,28 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptBool(tcpip.DelayOption, true) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + var synRetries tcpip.TCPSynRetriesOption + if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil { + e.maxSynRetries = uint8(synRetries) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - e.workMu.Lock() e.tsOffset = timeStampOffset() + e.acceptCond = sync.NewCond(&e.acceptMu) return e } @@ -589,26 +900,25 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) - e.mu.RLock() - defer e.mu.RUnlock() - - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case StateClose, StateError: + case StateClose, StateError, StateTimeWait: // Ready for anything. result = mask case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { + e.acceptMu.Lock() if len(e.acceptedChan) > 0 { result |= waiter.EventIn } + e.acceptMu.Unlock() } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -655,69 +965,117 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.LockUser() + defer e.UnlockUser() + if e.closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. - e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - - e.mu.Lock() + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} + } + + // Mark endpoint as closed. + e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return } // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { - e.cleanupLocked() - } else { + if e.workerRunning { e.workerCleanup = true + tcpip.AddDanglingEndpoint(e) + // Worker will remove the dangling endpoint when the endpoint + // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() } - - e.mu.Unlock() } // closePendingAcceptableConnections closes all connections that have completed // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { - done := make(chan struct{}) - // Spin a goroutine up as ranging on e.acceptedChan will just block when - // there are no more connections in the channel. Using a non-blocking - // select does not work as it can potentially select the default case - // even when there are pending writes but that are not yet written to - // the channel. - go func() { - defer close(done) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - }() - // pendingAccepted(see endpoint.deliverAccepted) tracks the number of - // endpoints which have completed handshake but are not yet written to - // the e.acceptedChan. We wait here till the goroutine above can drain - // all such connections from e.acceptedChan. - e.pendingAccepted.Wait() + e.acceptMu.Lock() + if e.acceptedChan == nil { + e.acceptMu.Unlock() + return + } close(e.acceptedChan) - <-done + ch := e.acceptedChan e.acceptedChan = nil + e.acceptCond.Broadcast() + e.acceptMu.Unlock() + + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. + e.pendingAccepted.Wait() } // cleanupLocked frees all resources associated with the endpoint. It is called @@ -726,22 +1084,25 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. - if e.acceptedChan != nil { - e.closePendingAcceptableConnectionsLocked() - } + e.closePendingAcceptableConnectionsLocked() + e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,16 +1113,34 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } + rcvWndScale := e.rcvWndScaleForHandshake() + + // Round-down the rcvWnd to a multiple of wndScale. This ensures that the + // window offered in SYN won't be reduced due to the loss of precision if + // window scaling is enabled after the handshake. + rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + + // Ensure we can always accept at least 1 byte if the scale specified + // was too high for the provided rcvWnd. + if rcvWnd == 0 { + rcvWnd = 1 + } + return rcvWnd } // ModerateRecvBuf adjusts the receive buffer and the advertised window -// based on the number of bytes copied to user space. +// based on the number of bytes copied to userspace. func (e *endpoint) ModerateRecvBuf(copied int) { + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() @@ -807,8 +1186,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = rcvWnd - e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + availAfter := e.receiveBufferAvailableLocked() + mask := uint32(notifyReceiveWindowChanged) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + mask |= notifyNonZeroReceiveWindow + } + e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -822,36 +1207,50 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { - return e.stack.IPTables(), nil +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err } // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() + e.LockUser() + defer e.UnlockUser() + + // When in SYN-SENT state, let the caller block on the receive. + // An application can initiate a non-blocking connect and then block + // on a receive. It can expect to read any data after the handshake + // is complete. RFC793, section 3.9, p58. + if e.EndpointState() == StateSynSent { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected } v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() - if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() } @@ -860,7 +1259,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -877,11 +1276,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { } e.rcvBufUsed -= len(v) - // If the window was zero before this read and if the read freed up - // enough buffer space for the scaled window to be non-zero then notify - // the protocol goroutine to send a window update. - if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) { - e.zeroWindow = false + + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -895,8 +1295,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -922,13 +1322,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. - e.mu.RLock() + e.LockUser() e.sndBufMu.Lock() avail, err := e.isEndpointWritableLocked() if err != nil { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -940,73 +1340,72 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // are copying data in. if !opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } // Fetch data. v, perr := p.Payload(avail) if perr != nil || len(v) == 0 { - if opts.Atomic { // See above. + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } - // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } + // Do the work inline. + e.handleWrite() + e.UnlockUser() + return int64(len(v)), nil, nil + } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } + if opts.Atomic { + // Locks released in queueAndSend() + return queueAndSend() } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - if e.workMu.TryLock() { - // Do the work inline. - e.handleWrite() - e.workMu.Unlock() - } else { - // Let the protocol goroutine do the work. - e.sndWaker.Assert() + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return int64(len(v)), nil, nil + // Locks released in queueAndSend() + return queueAndSend() } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1018,7 +1417,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1055,37 +1454,174 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// zeroReceiveWindow checks if the receive window to be announced now would be -// zero, based on the amount of available buffer and the receive window scaling. +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If +// window grows to reasonable value, we should send ACK to the sender to inform +// the rx space is now large. We also want ensure a series of small read()'s +// won't trigger a flood of spurious tiny ACK's. // -// It must be called with rcvListMu held. -func (e *endpoint) zeroReceiveWindow(scale uint8) bool { - if e.rcvBufUsed >= e.rcvBufSize { - return true +// For large receive buffers, the threshold is aMSS - once reader reads more +// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of +// receive buffer size. This is chosen arbitrairly. +// crossed will be true if the window size crossed the ACK threshold. +// above will be true if the new window is >= ACK threshold and false +// otherwise. +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { + newAvail := e.receiveBufferAvailableLocked() + oldAvail := newAvail - deltaBefore + if oldAvail < 0 { + oldAvail = 0 + } + + threshold := int(e.amss) + if threshold > e.rcvBufSize/2 { + threshold = e.rcvBufSize / 2 + } + + switch { + case oldAvail < threshold && newAvail >= threshold: + return true, true + case oldAvail >= threshold && newAvail < threshold: + return true, false + } + return false, false +} + +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() + + case tcpip.ReusePortOption: + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + // We only allow this to be set when we're in the initial state. + if e.EndpointState() != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.LockUser() + e.v6only = v + e.UnlockUser() } - return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 + return nil } // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. var rs ReceiveBufferSizeOption - size := int(v) if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if size < rs.Min { - size = rs.Min + if v < rs.Min { + v = rs.Min } - if size > rs.Max { - size = rs.Max + if v > rs.Max { + v = rs.Max } } mask := uint32(notifyReceiveWindowChanged) + e.LockUser() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1094,179 +1630,119 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { if e.rcv != nil { scale = e.rcv.rcvWndScale } - if size>>scale == 0 { - size = 1 << scale + if v>>scale == 0 { + v = 1 << scale } // Make sure 2*size doesn't overflow. - if size > math.MaxInt32/2 { - size = math.MaxInt32 / 2 + if v > math.MaxInt32/2 { + v = math.MaxInt32 / 2 } - e.rcvBufSize = size + availBefore := e.receiveBufferAvailableLocked() + e.rcvBufSize = v + availAfter := e.receiveBufferAvailableLocked() + e.rcvAutoParams.disabled = true - if e.zeroWindow && !e.zeroReceiveWindow(scale) { - e.zeroWindow = false + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } - e.rcvListMu.Unlock() + e.rcvListMu.Unlock() + e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - size := int(v) var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if size < ss.Min { - size = ss.Min + if v < ss.Min { + v = ss.Min } - if size > ss.Max { - size = ss.Max + if v > ss.Max { + v = ss.Max } } e.sndBufMu.Lock() - e.sndBufSize = size + e.sndBufSize = v e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) + case tcpip.TCPSynCountOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue } - return nil + e.LockUser() + e.maxSynRetries = uint8(v) + e.UnlockUser() - default: - return nil + case tcpip.TCPWindowClampOption: + if v == 0 { + e.LockUser() + switch e.EndpointState() { + case StateClose, StateInitial: + e.windowClamp = 0 + e.UnlockUser() + return nil + default: + e.UnlockUser() + return tcpip.ErrInvalidOptionValue + } + } + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if v < rs.Min/2 { + v = rs.Min / 2 + } + } + e.LockUser() + e.windowClamp = uint32(v) + e.UnlockUser() } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - for nicid, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicid - return nil - } - } - return tcpip.ErrUnknownDevice - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + e.LockUser() + e.bindToDevice = id + e.UnlockUser() case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - return nil + case tcpip.TCPUserTimeoutOption: + e.LockUser() + e.userTimeout = time.Duration(v) + e.UnlockUser() case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1279,22 +1755,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { availCC := strings.Split(string(avail), " ") for _, cc := range availCC { if v == tcpip.CongestionControlOption(cc) { - // Acquire the work mutex as we may need to - // reinitialize the congestion control state. - e.mu.Lock() - state := e.state + e.LockUser() + state := e.EndpointState() e.cc = v - e.mu.Unlock() switch state { case StateEstablished: - e.workMu.Lock() - e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } - e.mu.Unlock() - e.workMu.Unlock() } + e.UnlockUser() return nil } } @@ -1303,34 +1773,44 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.mu.Lock() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() - return nil + case tcpip.TCPLingerTimeoutOption: + e.LockUser() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + // Cap it to MaxTCPLingerTimeout. + stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.UnlockUser() - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + case tcpip.TCPDeferAcceptOption: + e.LockUser() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.UnlockUser() + + case tcpip.SocketDetachFilterOption: return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1340,9 +1820,100 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.portFlags.TupleOnly + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.portFlags.LoadBalanced + e.UnlockUser() + + return v, nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.LockUser() + v := e.v6only + e.UnlockUser() + + return v, nil + + case tcpip.MulticastLoopOption: + return true, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1358,12 +1929,26 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil + + case tcpip.TCPSynCountOption: + e.LockUser() + v := int(e.maxSynRetries) + e.UnlockUser() + return v, nil + + case tcpip.TCPWindowClampOption: + e.LockUser() + v := int(e.windowClamp) + e.UnlockUser() + return v, nil + + case tcpip.MulticastTTLOption: + return 1, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1374,191 +1959,84 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err - - case *tcpip.MaxSegOption: - // This is just stubbed out. Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.reuseAddr - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil + return e.takeLastError() case *tcpip.BindToDeviceOption: - e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil + e.LockUser() + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.UnlockUser() case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} - e.mu.RLock() + e.LockUser() snd := e.snd - e.mu.RUnlock() + e.UnlockUser() if snd != nil { snd.rtt.Lock() o.RTT = snd.rtt.srtt o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil + case *tcpip.TCPUserTimeoutOption: + e.LockUser() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.UnlockUser() case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.mu.Lock() - v := e.broadcast - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: - e.mu.Lock() + e.LockUser() *o = e.cc - e.mu.Unlock() - return nil + e.UnlockUser() - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil + case *tcpip.TCPLingerTimeoutOption: + e.LockUser() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.UnlockUser() - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil + case *tcpip.TCPDeferAcceptOption: + e.LockUser() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.UnlockUser() + + case *tcpip.OriginalDestinationOption: + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID) + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } default: return tcpip.ErrUnknownProtocolOption } + return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + if err != nil { + return tcpip.FullAddress{}, 0, err } - - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1583,17 +2061,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // yet accepted by the app, they are restored without running the main goroutine // here. func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1604,8 +2082,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC - switch e.state { + nicID := addr.NIC + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1613,11 +2091,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1636,14 +2114,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -1651,7 +2127,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -1666,7 +2142,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1674,44 +2150,95 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc h.Write(portBuf) portOffset := h.Sum32() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err)) + } + + reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal + if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { + switch netProto { + case header.IPv4ProtocolNumber: + reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + case header.IPv6ProtocolNumber: + reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + } + } + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { - return false, nil + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if err != tcpip.ErrPortInUse || !reuse { + return false, nil + } + transEPID := e.ID + transEPID.LocalPort = p + // Check if an endpoint is registered with demuxer in TIME-WAIT and if + // we can reuse it. If we can't find a transport endpoint then we just + // skip using this port as it's possible that either an endpoint has + // bound the port but not registered with demuxer yet (no listen/connect + // done yet) or the reservation was freed between the check above and + // the FindTransportEndpoint below. But rather than retry the same port + // we just skip it and move on. + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + if transEP == nil { + // ReservePort failed but there is no registered endpoint with + // demuxer. Which indicates there is at least some endpoint that has + // bound the port. + return false, nil + } + + tcpEP := transEP.(*endpoint) + tcpEP.LockUser() + // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but + // less than 1 second has elapsed since its recentTS was updated then + // we cannot reuse the port. + if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + tcpEP.UnlockUser() + return false, nil + } + // Since the endpoint is in TIME-WAIT it should be safe to acquire its + // Lock while holding the lock for this endpoint as endpoints in + // TIME-WAIT do not acquire locks on other endpoints. + tcpEP.workerCleanup = false + tcpEP.cleanupLocked() + tcpEP.notifyProtocolGoroutine(notifyAbort) + tcpEP.UnlockUser() + // Now try and Reserve again if it fails then we skip. + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + return false, nil + } } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { - case nil: - e.ID = id - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) - e.isPortReserved = false - } - e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1730,14 +2257,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1751,14 +2277,17 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - e.shutdownFlags |= flags + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { + e.shutdownFlags |= flags switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. - if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. e.rcvListMu.Lock() e.rcvClosed = true @@ -1767,47 +2296,56 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. - if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) return nil } } // Close for write. - if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() - break + if e.EndpointState() == StateTimeWait { + return tcpip.ErrNotConnected + } + return nil } // Queue fin segment. s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() + e.handleClose() } - case e.state == StateListen: - // Tell protocolListenLoop to stop. - if flags&tcpip.ShutdownRead != 0 { - e.notifyProtocolGoroutine(notifyClose) + return nil + case e.EndpointState() == StateListen: + if e.shutdownFlags&tcpip.ShutdownRead != 0 { + // Reset all connections from the accept queue and keep the + // worker running so that it can continue handling incoming + // segments by replying with RST. + // + // By not removing this endpoint from the demuxer mapping, we + // ensure that any other bind to the same port fails, as on Linux. + e.rcvListMu.Lock() + e.rcvClosed = true + e.rcvListMu.Unlock() + e.closePendingAcceptableConnectionsLocked() + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) } - + return nil default: return tcpip.ErrNotConnected } - - return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept @@ -1822,104 +2360,136 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } func (e *endpoint) listen(backlog int) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // Allow the backlog to be adjusted if the endpoint is not shutting down. - // When the endpoint shuts down, it sets workerCleanup to true, and from - // that point onward, acceptedChan is the responsibility of the cleanup() - // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { - // Adjust the size of the channel iff we can fix existing - // pending connections into the new one. - if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState - } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep + e.LockUser() + defer e.UnlockUser() + + if e.EndpointState() == StateListen && !e.closed { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + if e.acceptedChan == nil { + // listen is called after shutdown. + e.acceptedChan = make(chan *endpoint, backlog) + e.shutdownFlags = 0 + e.rcvListMu.Lock() + e.rcvClosed = false + e.rcvListMu.Unlock() + } else { + // Adjust the size of the channel iff we can fix + // existing pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + if cap(e.acceptedChan) == backlog { + return nil + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *endpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } } + + // Notify any blocked goroutines that they can attempt to + // deliver endpoints again. + e.acceptCond.Broadcast() + return nil } + if e.EndpointState() == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + + // The channel may be non-nil when we're restoring the endpoint, and it + // may be pre-populated with some previously accepted (but not Accepted) + // endpoints. + e.acceptMu.Lock() if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } - e.workerRunning = true + e.acceptMu.Unlock() + e.workerRunning = true go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { - e.waiterQueue = waiterQueue +func (e *endpoint) startAcceptedLoop() { e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if rcvClosed || e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } // Get the new accepted endpoint. + e.acceptMu.Lock() + defer e.acceptMu.Unlock() var n *endpoint select { case n = <-e.acceptedChan: + e.acceptCond.Signal() default: return nil, nil, tcpip.ErrWouldBlock } - - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() + + return e.bindLocked(addr) +} +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -1935,26 +2505,30 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.bindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1968,16 +2542,20 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } // GetLocalAddress returns the address to which the endpoint is bound. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() return tcpip.FullAddress{ Addr: e.ID.LocalAddress, @@ -1988,10 +2566,10 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the address to which the endpoint is connected. func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2002,45 +2580,26 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. +} +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2051,6 +2610,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } @@ -2079,20 +2650,16 @@ func (e *endpoint) readyToRead(s *segment) { if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() - // Check if the receive window is now closed. If so make sure - // we set the zero window before we deliver the segment to ensure - // that a subsequent read of the segment will correctly trigger - // a non-zero notification. - if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + // Increase counter if the receive window falls down below MSS + // or half receive buffer size, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - e.zeroWindow = true } e.rcvList.PushBack(s) } else { e.rcvClosed = true } e.rcvListMu.Unlock() - e.waiterQueue.Notify(waiter.EventIn) } @@ -2156,8 +2723,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2167,22 +2734,21 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.tsOffset) + return tcpTimeStamp(time.Now(), e.tsOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(offset uint32) uint32 { - now := time.Now() - return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { + return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending @@ -2236,9 +2802,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SegTime = time.Now() // Copy EndpointID. - e.mu.Lock() s.ID = stack.TCPEndpointID(e.ID) - e.mu.Unlock() // Copy endpoint rcv state. e.rcvListMu.Lock() @@ -2256,7 +2820,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2327,6 +2891,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { WEst: cubic.wEst, } } + + rc := e.snd.rc + s.Sender.RACKState = stack.TCPRACKState{ + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + } return s } @@ -2363,17 +2935,15 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() + e.LockUser() // Make a copy of the endpoint info. ret := e.EndpointInfo - e.mu.RUnlock() + e.UnlockUser() return &ret } @@ -2382,6 +2952,18 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } -func mssForRoute(r *stack.Route) uint16 { - return uint16(r.MTU() - header.TCPMinimumSize) +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.LockUser() + running := e.workerRunning + e.UnlockUser() + if !running { + break + } + <-notifyCh + } } |