diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2019-05-30 10:47:11 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-05-30 12:08:41 -0700 |
commit | ae26b2c425d53aa8720238183d1c156a45904311 (patch) | |
tree | e7eff51fb10a29493586d4193acb3b1047ce1d57 /pkg | |
parent | 8d25cd0b40694d1911724816d72b34d0717878d6 (diff) |
Fixes to TCP listen behavior.
Netstack listen loop can get stuck if cookies are in-use and the app is slow to
accept incoming connections. Further we continue to complete handshake for a
connection even if the backlog is full. This creates a problem when a lots of
connections come in rapidly and we end up with lots of completed connections
just hanging around to be delivered.
These fixes change netstack behaviour to mirror what linux does as described
here in the following article
http://veithen.io/2014/01/01/how-tcp-backlog-works-in-linux.html
Now when cookies are not in-use Netstack will silently drop the ACK to a SYN-ACK
and not complete the handshake if the backlog is full. This will result in the
connection staying in a half-complete state. Eventually the sender will
retransmit the ACK and if backlog has space we will transition to a connected
state and deliver the endpoint.
Similarly when cookies are in use we do not try and create an endpoint unless
there is space in the accept queue to accept the newly created endpoint. If
there is no space then we again silently drop the ACK as we can just recreate it
when the ACK is retransmitted by the peer.
We also now use the backlog to cap the size of the SYN-RCVD queue for a given
endpoint. So at any time there can be N connections in the backlog and N in a
SYN-RCVD state if the application is not accepting connections. Any new SYNs
will be dropped.
This CL also fixes another small bug where we mark a new endpoint which has not
completed handshake as connected. We should wait till handshake successfully
completes before marking it connected.
Updates #236
PiperOrigin-RevId: 250717817
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 161 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 499 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 6 |
8 files changed, 644 insertions, 111 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 31a449cf2..de4b963da 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -140,21 +140,26 @@ var Metrics = tcpip.Stats{ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), }, TCP: tcpip.TCPStats{ - ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), - PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - FailedConnectionAttempts: mustCreateMetric("/netstack/tcp/failed_connection_attempts", "Number of calls to Connect or Listen (active and passive openings, respectively) that end in an error."), - ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), - InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), - SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), - ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), - ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), - Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), - FastRecovery: mustCreateMetric("/netstack/tcp/fast_recovery", "Number of times fast recovery was used to recover from packet loss."), - SACKRecovery: mustCreateMetric("/netstack/tcp/sack_recovery", "Number of times SACK recovery was used to recover from packet loss."), - SlowStartRetransmits: mustCreateMetric("/netstack/tcp/slow_start_retransmits", "Number of segments retransmitted in slow start mode."), - FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), - Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), - ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), + ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), + PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), + ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), + ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), + ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), + ListenOverflowSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_rcvd", "Number of times a SYN cookie was received."), + ListenOverflowInvalidSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_invalid_syn_cookie_rcvd", "Number of times an invalid SYN cookie was received."), + FailedConnectionAttempts: mustCreateMetric("/netstack/tcp/failed_connection_attempts", "Number of calls to Connect or Listen (active and passive openings, respectively) that end in an error."), + ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), + InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), + SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), + ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), + ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), + Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), + FastRecovery: mustCreateMetric("/netstack/tcp/fast_recovery", "Number of times fast recovery was used to recover from packet loss."), + SACKRecovery: mustCreateMetric("/netstack/tcp/sack_recovery", "Number of times SACK recovery was used to recover from packet loss."), + SlowStartRetransmits: mustCreateMetric("/netstack/tcp/slow_start_retransmits", "Number of segments retransmitted in slow start mode."), + FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), + Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), + ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c8164c0f0..f9886c6e4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -760,6 +760,25 @@ type TCPStats struct { // successfully via Listen. PassiveConnectionOpenings *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop *StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop *StatCounter + + // ListenOverflowCookieSent is the number of times a SYN cookie was sent. + ListenOverflowSynCookieSent *StatCounter + + // ListenOverflowSynCookieRcvd is the number of times a valid SYN + // cookie was received. + ListenOverflowSynCookieRcvd *StatCounter + + // ListenOverflowInvalidSynCookieRcvd is the number of times an invalid SYN cookie + // was received. + ListenOverflowInvalidSynCookieRcvd *StatCounter + // FailedConnectionAttempts is the number of calls to Connect or Listen // (active and passive openings, respectively) that end in an error. FailedConnectionAttempts *StatCounter diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e506d7133..d4b860975 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "hash" "io" + "log" "sync" "time" @@ -87,9 +88,10 @@ var synRcvdCount struct { // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte + stack *stack.Stack + rcvWnd seqnum.Size + nonce [2][sha1.BlockSize]byte + listenEP *endpoint hasherMu sync.Mutex hasher hash.Hash @@ -107,15 +109,16 @@ func timeStamp() uint32 { // threshold, and fails otherwise. func incSynRcvdCount() bool { synRcvdCount.Lock() - defer synRcvdCount.Unlock() if synRcvdCount.value >= SynRcvdCountThreshold { + synRcvdCount.Unlock() return false } synRcvdCount.pending.Add(1) synRcvdCount.value++ + synRcvdCount.Unlock() return true } @@ -124,20 +127,21 @@ func incSynRcvdCount() bool { // succeeded. func decSynRcvdCount() { synRcvdCount.Lock() - defer synRcvdCount.Unlock() synRcvdCount.value-- synRcvdCount.pending.Done() + synRcvdCount.Unlock() } // newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stack, rcvWnd: rcvWnd, hasher: sha1.New(), v6only: v6only, netProto: netProto, + listenEP: listenEP, } rand.Read(l.nonce[0][:]) @@ -195,9 +199,9 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -// createConnectedEndpoint creates a new connected endpoint, with the connection -// parameters given by the arguments. -func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +// createConnectingEndpoint creates a new endpoint in a connecting state, with +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -223,7 +227,7 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir } n.isRegistered = true - n.state = stateConnected + n.state = stateConnecting // Create sender and receiver. // @@ -241,7 +245,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) + ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) if err != nil { return nil, err } @@ -249,12 +253,15 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, l.rcvWnd) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(cookie, irs, opts, l.listenEP) if err := h.execute(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() return nil, err } + ep.state = stateConnected + // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window // scaling. @@ -268,13 +275,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() - if e.state == stateListen { + state := e.state + e.mu.RUnlock() + if state == stateListen { e.acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { n.Close() } - e.mu.RUnlock() } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -285,16 +293,36 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer decSynRcvdCount() + defer e.decSynRcvdCount() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() return } e.deliverAccepted(n) } +func (e *endpoint) incSynRcvdCount() bool { + e.mu.Lock() + log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount) + if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c { + e.mu.Unlock() + return false + } + e.synRcvdCount++ + e.mu.Unlock() + return true +} + +func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() +} + // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { @@ -302,9 +330,20 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { case header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { - s.incRef() - go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + // Drop the SYN if the listen endpoint's accept queue is + // overflowing. + if e.incSynRcvdCount() { + log.Printf("processing syn packet") + s.incRef() + go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + return + } + log.Printf("dropping syn packet") + e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return } else { + // TODO(bhaskerh): Increment syncookie sent stat. cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) // Send SYN with window scaling because we currently // dont't encode this information in the cookie. @@ -318,36 +357,72 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, } sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } case header.TCPFlagAck: - if data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1); ok && int(data) < len(mssTable) { - // Create newly accepted endpoint and deliver it. - rcvdSynOptions := &header.TCPSynOptions{ - MSS: mssTable[data], - // Disable Window scaling as original SYN is - // lost. - WS: -1, - } - // When syn cookies are in use we enable timestamp only - // if the ack specifies the timestamp option assuming - // that the other end did in fact negotiate the - // timestamp option in the original SYN. - if s.parsedOptions.TS { - rcvdSynOptions.TS = true - rcvdSynOptions.TSVal = s.parsedOptions.TSVal - rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr - } - n, err := ctx.createConnectedEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) - if err == nil { - // clear the tsOffset for the newly created - // endpoint as the Timestamp was already - // randomly offset when the original SYN-ACK was - // sent above. - n.tsOffset = 0 - e.deliverAccepted(n) - } + if len(e.acceptedChan) == cap(e.acceptedChan) { + // Silently drop the ack as the application can't accept + // the connection at this point. The ack will be + // retransmitted by the sender anyway and we can + // complete the connection at the time of retransmit if + // the backlog has space. + e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return + } + + // Validate the cookie. + data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) + if !ok || int(data) >= len(mssTable) { + e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() + e.stack.Stats().DroppedPackets.Increment() + return } + e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() + // Create newly accepted endpoint and deliver it. + rcvdSynOptions := &header.TCPSynOptions{ + MSS: mssTable[data], + // Disable Window scaling as original SYN is + // lost. + WS: -1, + } + + // When syn cookies are in use we enable timestamp only + // if the ack specifies the timestamp option assuming + // that the other end did in fact negotiate the + // timestamp option in the original SYN. + if s.parsedOptions.TS { + rcvdSynOptions.TS = true + rcvdSynOptions.TSVal = s.parsedOptions.TSVal + rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr + } + + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + return + } + + // clear the tsOffset for the newly created + // endpoint as the Timestamp was already + // randomly offset when the original SYN-ACK was + // sent above. + n.tsOffset = 0 + + // Switch state to connected. + n.state = stateConnected + + // Do the delivery in a separate goroutine so + // that we don't block the listen loop in case + // the application is slow to accept or stops + // accepting. + // + // NOTE: This won't result in an unbounded + // number of goroutines as we do check before + // entering here that there was at least some + // space available in the backlog. + go e.deliverAccepted(n) } } @@ -377,7 +452,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3b927d82e..2aed6f286 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -60,11 +60,12 @@ const ( // handshake holds the state used during a TCP 3-way handshake. type handshake struct { - ep *endpoint - state handshakeState - active bool - flags uint8 - ackNum seqnum.Value + ep *endpoint + listenEP *endpoint // only non nil when doing passive connects. + state handshakeState + active bool + flags uint8 + ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. iss seqnum.Value @@ -141,7 +142,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -149,6 +150,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.listenEP = listenEP } // checkAck checks if the ACK number, if present, of a segment received during @@ -279,7 +281,18 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { - + // listenContext is also used by a tcp.Forwarder and in that + // context we do not have a listening endpoint to check the + // backlog. So skip this check if listenEP is nil. + if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) { + // If there is no space in the accept queue to accept + // this endpoint then silently drop this ACK. The peer + // will anyway resend the ack and we can complete the + // connection the next time it's retransmitted. + h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 00962a63e..b66610ee2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -198,6 +198,10 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` + // synRcvdCount is the number of connections for this endpoint that are + // in SYN-RCVD state. + synRcvdCount int + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -1302,7 +1306,6 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } e.workerRunning = true - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) @@ -1339,6 +1342,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { // Start the protocol goroutine. wq := &waiter.Queue{} n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() return n, wq, nil } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index e088e24cb..c30b45c2c 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -53,7 +53,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), - listen: newListenContext(s, seqnum.Size(rcvWnd), true, 0), + listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e341bb4aa..fe037602b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -124,50 +124,6 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { } } -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := ep.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - if err := ep.Listen(1); err != tcpip.ErrInvalidEndpointState { - t.Errorf("got ep.Listen(1) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } -} - func TestTCPSegmentsSentIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3900,3 +3856,458 @@ func TestKeepalive(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + +func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + // Send a SYN request. + irs = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss = seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + + if synCookieInUse { + // When cookies are in use window scaling is disabled. + tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ + WS: -1, + MSS: c.MSSWithoutOptions(), + })) + } + + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + return irs, iss +} + +// TestListenBacklogFull tests that netstack does not complete handshakes if the +// listen backlog for the endpoint is full. +func TestListenBacklogFull(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + // Start listening. + listenBacklog := 2 + if err := c.EP.Listen(listenBacklog); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + for i := 0; i < listenBacklog; i++ { + executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + } + + time.Sleep(50 * time.Millisecond) + + // Now execute one more handshake. This should not be completed and + // delivered on an Accept() call as the backlog is full at this point. + irs, iss := executeHandshake(t, c, context.TestPort+uint16(listenBacklog), false /* synCookieInUse */) + + time.Sleep(50 * time.Millisecond) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + for i := 0; i < listenBacklog; i++ { + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + } + + // Now verify that there are no more connections that can be accepted. + _, _, err = c.EP.Accept() + if err != tcpip.ErrWouldBlock { + select { + case <-ch: + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) + case <-time.After(1 * time.Second): + } + } + + // Now craft the ACK again and verify that the connection is now ready + // to be accepted. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + uint16(listenBacklog), + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + if string(tcp.Payload()) != data { + t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + } +} + +func TestListenBacklogFullSynCookieInUse(t *testing.T) { + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 1 + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + // Start listening. + listenBacklog := 1 + portOffset := uint16(0) + if err := c.EP.Listen(listenBacklog); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + executeHandshake(t, c, context.TestPort+portOffset, false) + portOffset++ + // Wait for this to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + + nonCookieIRS, nonCookieISS := executeHandshake(t, c, context.TestPort+portOffset, false) + + // Since the backlog is full at this point this connection will not + // transition out of handshake and ignore the ACK. + // + // At this point there should be 1 completed connection in the backlog + // and one incomplete one pending for a final ACK and hence not ready to be + // delivered to the endpoint. + // + // Now execute one more handshake. This should not be completed and + // delivered on an Accept() call as the backlog is full at this point + // and there is already 1 pending endpoint. + // + // This one should use a SYN cookie as the synRcvdCount is equal to the + // SynRcvdCountThreshold. + time.Sleep(50 * time.Millisecond) + portOffset++ + irs, iss := executeHandshake(t, c, context.TestPort+portOffset, true) + + time.Sleep(50 * time.Millisecond) + + // Verify that there is only one acceptable connection at this point. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that there are no more connections that can be accepted. + _, _, err = c.EP.Accept() + if err != tcpip.ErrWouldBlock { + select { + case <-ch: + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) + case <-time.After(1 * time.Second): + } + } + + // Now send an ACK for the half completed connection + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + portOffset - 1, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: nonCookieIRS + 1, + AckNum: nonCookieISS + 1, + RcvWnd: 30000, + }) + + // Verify that the connection is now delivered to the backlog. + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Finally send an ACK for the connection that used a cookie and verify that + // it's also completed and delivered. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + portOffset, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs, + AckNum: iss, + RcvWnd: 30000, + }) + + time.Sleep(50 * time.Millisecond) + newEP, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + if string(tcp.Payload()) != data { + t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + } +} + +func TestPassiveConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + stats := c.Stack().Stats() + want := stats.TCP.PassiveConnectionOpenings.Value() + 1 + + srcPort := uint16(context.TestPort) + executeHandshake(t, c, srcPort+1, false) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Verify that there is only one acceptable connection at this point. + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + srcPort := uint16(context.TestPort) + // Now attempt 3 handshakes, the first two will fill up the accept and the SYN-RCVD + // queue for the endpoint. + executeHandshake(t, c, srcPort, false) + + // Give time for the final ACK to be processed as otherwise the next handshake could + // get accepted before the previous one based on goroutine scheduling. + time.Sleep(50 * time.Millisecond) + irs, iss := executeHandshake(t, c, srcPort+1, false) + + // Wait for a short while for the accepted connection to be delivered to + // the channel before trying to send the 3rd SYN. + time.Sleep(40 * time.Millisecond) + + want := stats.TCP.ListenOverflowSynDrop.Value() + 1 + + // Now we will send one more SYN and this one should get dropped + // Send a SYN request. + c.SendPacket(nil, &context.Headers{ + SrcPort: srcPort + 2, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(789), + RcvWnd: 30000, + }) + + time.Sleep(50 * time.Millisecond) + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) + } + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Now check that there is one acceptable connections. + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now complete the next connection in SYN-RCVD state as it should + // have dropped the final ACK to the handshake due to accept queue + // being full. + c.SendPacket(nil, &context.Headers{ + SrcPort: srcPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Now check that there is one more acceptable connections. + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Try and accept a 3rd one this should fail. + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + ep, _, err = c.EP.Accept() + if err == nil { + t.Fatalf("Accept succeeded when it should have failed got: %+v", ep) + } + + case <-time.After(1 * time.Second): + } + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e08eb6533..6e12413c6 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -989,3 +989,9 @@ func (c *Context) SACKEnabled() bool { func (c *Context) SetGSOEnabled(enable bool) { c.linkEP.GSO = enable } + +// MSSWithoutOptions returns the value for the MSS used by the stack when no +// options are in use. +func (c *Context) MSSWithoutOptions() uint16 { + return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) +} |