diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 54 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 458 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dual_stack_test.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 593 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 59 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 167 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_sack_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 1158 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 42 |
15 files changed, 2404 insertions, 299 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 39a839ab7..3f47b328d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "tcp_segment_list", out = "tcp_segment_list.go", @@ -45,10 +44,12 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", @@ -70,7 +71,7 @@ filegroup( go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..f24b51b91 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -229,7 +230,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n := newEndpoint(l.stack, netProto, nil) n.v6only = l.v6only - n.id = s.id + n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} @@ -242,7 +243,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } @@ -268,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -288,9 +289,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -298,7 +298,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() + ep.stack.Stats().TCP.CurrentEstablished.Increment() ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -311,14 +313,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.id] = n + l.pendingEndpoints[n.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.id) + delete(l.pendingEndpoints, n.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -360,12 +362,19 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -399,6 +408,17 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + switch s.flags { case header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -414,6 +434,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } decSynRcvdCount() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } else { @@ -421,6 +442,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // is full then drop the syn. if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -431,15 +453,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } - sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } @@ -451,6 +472,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // complete the connection at the time of retransmit if // the backlog has space. e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -505,6 +527,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } @@ -515,7 +538,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -526,6 +551,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } @@ -536,7 +566,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 21038a65a..a114c06c1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -78,9 +80,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -142,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -238,6 +262,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() h.ep.state = StateSynRecv + ttl := h.ep.ttl h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), @@ -251,8 +276,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { SACKPermitted: rcvSynOpts.SACKPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - + if ttl == 0 { + ttl = s.route.DefaultTTL() + } + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -296,7 +323,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -383,6 +410,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + if err == tcpip.ErrNoLinkAddress { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + } else if err != nil { + h.ep.stats.SendErrors.NoRoute.Increment() + } // Either success (err == nil) or failure. return err } @@ -437,7 +469,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -460,7 +492,8 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -469,7 +502,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -579,24 +612,30 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + // We ignore SYN send errors and let the callers re-attempt send. + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + e.stats.SendErrors.SynSendToNetworkFailed.Increment() + } putOptions(options) - return err + return nil } -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { + e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() + return err } + e.stats.SegmentsSent.Increment() + return nil +} +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { + optLen := len(opts) + hdr := &d.Hdr + packetSize := d.Size + off := d.Off // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ @@ -610,7 +649,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { @@ -620,16 +659,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVV(data, xsum) + xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } +} + +func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen) + + size := data.Size() + off := 0 + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + hdrs[i].Off = off + hdrs[i].Size = packetSize + buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso) + off += packetSize + seq = seq.Add(seqnum.Size(packetSize)) + } + if ttl == 0 { + ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + } + + d := &stack.PacketDescriptor{ + Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Off: 0, + Size: data.Size(), + } + buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso) + + if ttl == 0 { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } r.Stats().TCP.SegmentsSent.Increment() if (flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } - - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) + return nil } // makeOptions makes an options slice. @@ -678,7 +780,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } @@ -727,10 +829,26 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError - e.hardError = err + e.HardError = err if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } @@ -744,6 +862,51 @@ func (e *endpoint) completeWorkerLocked() { } } +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + switch e.state { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.state = StateClose + e.HardError = tcpip.ErrAborted + // We need to set this explicitly here because otherwise + // the port registrations will not be released till the + // endpoint is actively closed by the application. + e.workerCleanup = true + e.mu.Unlock() + return false, nil + default: + e.mu.Unlock() + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + // handleSegments pulls segments from the queue and processes them. It returns // no error if the protocol loop should continue, an error otherwise. func (e *endpoint) handleSegments() *tcpip.Error { @@ -761,14 +924,34 @@ func (e *endpoint) handleSegments() *tcpip.Error { } if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset + if ok, err := e.handleReset(s); !ok { + return err } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -777,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -876,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -897,8 +1087,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError - e.hardError = err + e.HardError = err // Lock released below. epilogue() @@ -920,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // RTT itself. e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.mu.Lock() + e.state = StateEstablished + e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -927,7 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -958,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.state = StateClose + e.mu.Unlock() + return nil }, }, { @@ -1001,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1033,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1059,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1083,15 +1292,156 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index c54610a87..dfaa4a559 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { } } -func testV4Connect(t *testing.T, c *context.Context) { +func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv4(t, b, synCheckers...) tcp := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv4(t, c.GetPacket(), ackCheckers...) // Wait for connection to be established. select { @@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { testV4Connect(t, c) } -func testV6Connect(t *testing.T, c *context.Context) { +func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt to IPv6 address. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetV6Packet() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv6(t, b, synCheckers...) tcp := header.TCP(header.IPv6(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv6(t, c.GetV6Packet(), ackCheckers...) // Wait for connection to be established. select { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 35b489c68..04c92c04c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "fmt" "math" "strings" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -119,6 +121,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -170,6 +177,101 @@ type rcvBufAutoTuneParams struct { disabled bool } +// ReceiveErrors collect segment receive errors within transport layer. +type ReceiveErrors struct { + tcpip.ReceiveErrors + + // SegmentQueueDropped is the number of segments dropped due to + // a full segment queue. + SegmentQueueDropped tcpip.StatCounter + + // ChecksumErrors is the number of segments dropped due to bad checksums. + ChecksumErrors tcpip.StatCounter + + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop tcpip.StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop tcpip.StatCounter + + // ZeroRcvWindowState is the number of times we advertised + // a zero receive window when rcvList is full. + ZeroRcvWindowState tcpip.StatCounter +} + +// SendErrors collect segment send errors within the transport layer. +type SendErrors struct { + tcpip.SendErrors + + // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent + // to the network endpoint. + SegmentSendToNetworkFailed tcpip.StatCounter + + // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent + // to the network endpoint. + SynSendToNetworkFailed tcpip.StatCounter + + // Retransmits is the number of TCP segments retransmitted. + Retransmits tcpip.StatCounter + + // FastRetransmit is the number of segments retransmitted in fast + // recovery. + FastRetransmit tcpip.StatCounter + + // Timeouts is the number of times the RTO expired. + Timeouts tcpip.StatCounter +} + +// Stats holds statistics about the endpoint. +type Stats struct { + // SegmentsReceived is the number of TCP segments received that + // the transport layer successfully parsed. + SegmentsReceived tcpip.StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent tcpip.StatCounter + + // FailedConnectionAttempts is the number of times we saw Connect and + // Accept errors. + FailedConnectionAttempts tcpip.StatCounter + + // ReceiveErrors collects segment receive errors within the + // transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects segment read errors from an endpoint read call. + ReadErrors tcpip.ReadErrors + + // SendErrors collects segment send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects segment write errors from an endpoint write call. + WriteErrors tcpip.WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*Stats) IsEndpointStats() {} + +// EndpointInfo holds useful information about a transport endpoint which +// can be queried by monitoring tools. +// +// +stateify savable +type EndpointInfo struct { + stack.TransportEndpointInfo + + // HardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. HardError is protected by endpoint mu. + HardError *tcpip.Error `state:".(string)"` +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*EndpointInfo) IsEndpointInfo() {} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -178,6 +280,8 @@ type rcvBufAutoTuneParams struct { // // +stateify savable type endpoint struct { + EndpointInfo + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -186,9 +290,9 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. - stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber + stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -218,14 +322,19 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` + ttl uint8 v6only bool isConnectNotified bool // TCP should never broadcast but Linux nevertheless supports enabling/ @@ -240,11 +349,6 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` - // hardError is meaningful only when state is stateError, it stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. hardError is protected by mu. - hardError *tcpip.Error `state:".(string)"` - // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -280,6 +384,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -315,7 +422,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -393,13 +500,49 @@ type endpoint struct { probe stack.TCPProbeFunc `state:"nosave"` // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address connectingAddress tcpip.Address // amss is the advertised MSS to the peer by this endpoint. amss uint16 + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + gso *stack.GSO + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests. @@ -427,10 +570,15 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, @@ -443,29 +591,40 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } var rs ReceiveBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } var cs tcpip.CongestionControlOption - if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } var mrb tcpip.ModerateReceiveBufferOption - if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - if p := stack.GetTCPProbe(); p != nil { + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -552,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -564,14 +730,16 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -597,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -625,16 +791,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -645,7 +812,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -731,11 +900,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.hardError + he := e.HardError e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } + e.stats.ReadErrors.InvalidEndpointState.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -744,6 +914,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RUnlock() + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } return v, tcpip.ControlMessages{}, err } @@ -787,7 +960,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { if !e.state.connected() { switch e.state { case StateError: - return 0, e.hardError + return 0, e.HardError default: return 0, tcpip.ErrClosedForSend } @@ -818,6 +991,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -852,6 +1026,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -863,7 +1038,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Add data to the send queue. - s := newSegmentFromView(&e.route, e.id, v) + s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -895,8 +1070,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.state; !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardError + return 0, tcpip.ControlMessages{}, e.HardError } + e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -905,6 +1081,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.state.connected() { + e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -1018,14 +1195,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1037,6 +1206,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1060,6 +1239,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicID, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicID + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -1074,14 +1268,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -1096,6 +1290,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 return nil + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + return nil + case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 @@ -1164,6 +1364,45 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. return tcpip.ErrNoSuchFile + + case tcpip.IPv4TOSOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1190,6 +1429,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1202,8 +1442,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1224,13 +1472,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1260,6 +1501,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1269,7 +1520,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -1283,6 +1534,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.mu.RLock() @@ -1347,13 +1604,31 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { @@ -1369,7 +1644,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -1383,7 +1658,12 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return e.connect(addr, true, true) + err := e.connect(addr, true, true) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err } // connect connects the endpoint to its peer. In the normal non-S/R case, the @@ -1392,14 +1672,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() connectingAddr := addr.Addr @@ -1419,7 +1694,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1428,11 +1703,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1444,29 +1719,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnecting case StateError: - return e.hardError + return e.HardError default: return tcpip.ErrInvalidEndpointState } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.id + origID := e.ID netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.id.LocalAddress = r.LocalAddress - e.id.RemoteAddress = r.RemoteAddress - e.id.RemotePort = addr.Port + e.ID.LocalAddress = r.LocalAddress + e.ID.RemoteAddress = r.RemoteAddress + e.ID.RemotePort = addr.Port - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1475,20 +1750,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.id.LocalAddress == e.id.RemoteAddress - if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - if sameAddr && p == e.id.RemotePort { + sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.Seed()) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.ID.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } - id := e.id + id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: - e.id = id + e.ID = id return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1504,14 +1794,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1523,7 +1813,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.id + s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() } @@ -1531,6 +1821,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1551,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1568,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1583,17 +1874,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.id, nil) + s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1601,24 +1889,37 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { +func (e *endpoint) Listen(backlog int) *tcpip.Error { + err := e.listen(backlog) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err +} + +func (e *endpoint) listen(backlog int) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1644,11 +1945,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { + e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1692,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1712,7 +2009,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return tcpip.ErrAlreadyBound } - e.bindAddress = addr.Addr + e.BindAddr = addr.Addr netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -1729,26 +2026,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } e.isPortReserved = true e.effectiveNetProtos = netProtos - e.id.LocalPort = port + e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1759,7 +2056,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } e.boundNICID = nic - e.id.LocalAddress = addr.Addr + e.ID.LocalAddress = addr.Addr } // Mark endpoint as bound. @@ -1774,8 +2071,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -1790,19 +2087,20 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, NIC: e.boundNICID, }, nil } // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() return } @@ -1810,27 +2108,34 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.csumValid { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() return } e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + e.stats.SegmentsReceived.Increment() if (s.flags & header.TCPFlagRst) != 0 { e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() } else { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.SegmentQueueDropped.Increment() s.decRef() } } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -1874,6 +2179,7 @@ func (e *endpoint) readyToRead(s *segment) { // that a subsequent read of the segment will correctly trigger // a non-zero notification. if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } e.rcvList.PushBack(s) @@ -2026,7 +2332,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Copy EndpointID. e.mu.Lock() - s.ID = stack.TCPEndpointID(e.id) + s.ID = stack.TCPEndpointID(e.ID) e.mu.Unlock() // Copy endpoint rcv state. @@ -2119,11 +2425,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2133,7 +2435,7 @@ func (e *endpoint) initGSO() { gso.Type = stack.GSOTCPv6 gso.L3HdrLen = header.IPv6MinimumSize default: - panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } gso.NeedsCsum = true gso.CsumOffset = header.TCPChecksumOffset @@ -2141,6 +2443,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2149,6 +2463,37 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 831389ec7..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() { case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress + if len(e.BindAddr) == 0 { + e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -202,21 +209,31 @@ func (e *endpoint) Resume(s *stack.Stack) { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress + e.connectingAddress = e.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -236,7 +253,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } @@ -288,21 +309,21 @@ func (e *endpoint) loadLastError(s string) { } // saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { return "" } - return e.hardError.String() + return e.HardError.String() } // loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { +func (e *EndpointInfo) loadHardError(s string) { if s == "" { return } - e.hardError = loadError(s) + e.HardError = loadError(s) } var messageToError map[string]*tcpip.Error diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index d5d8ab96a..89b965c23 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,12 +55,23 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP // protocol. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +type DelayEnabled bool + // SendBufferSizeOption allows the default, min and max send buffer sizes for // TCP endpoints to be queried or configured. type SendBufferSizeOption struct { @@ -84,11 +96,14 @@ const ( type protocol struct { mu sync.Mutex sackEnabled bool + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -153,7 +168,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -165,6 +180,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -202,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -216,6 +255,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *DelayEnabled: + p.mu.Lock() + *v = DelayEnabled(p.delayEnabled) + p.mu.Unlock() + return nil + case *SendBufferSizeOption: p.mu.Lock() *v = p.sendBufferSize @@ -246,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -258,5 +315,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..068b90fb6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: @@ -253,23 +259,105 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } // Defer segment processing if it can't be consumed now. @@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } @@ -99,8 +100,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 735edfe55..d3f7c9125 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -417,6 +417,7 @@ func (s *sender) resendSegment() { s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() // Run SetPipe() as per RFC 6675 section 5 Step 4.4 s.SetPipe() @@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool { } s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() // Give up if we've waited more than a minute since the last resend. if s.rto >= 60*time.Second { @@ -672,6 +674,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } + s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -1188,6 +1191,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { func (s *sender) sendSegment(seg *segment) *tcpip.Error { if !seg.xmitTime.IsZero() { s.ep.stack.Stats().TCP.Retransmits.Increment() + s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 9fa97528b..782d7b42c 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 4e7f1a740..afea124ec 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) { t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) // Acknowledge all pending data to recover point. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2be094876..84579ce52 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -99,7 +99,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) } } @@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + } } func TestTCPSegmentsSentIncrement(t *testing.T) { @@ -136,6 +142,9 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + } } func TestTCPResetsSentIncrement(t *testing.T) { @@ -197,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -247,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -255,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -276,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -367,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -387,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -404,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), ), @@ -465,6 +506,347 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestTOSV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + var v tcpip.IPv4TOSOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testV4Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestTrafficClassV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + } + + var v tcpip.IPv6TrafficClassOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + // Test the connection request. + testV6Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -760,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -797,6 +1178,10 @@ func TestShutdownRead(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + } } func TestFullWindowReceive(t *testing.T) { @@ -853,6 +1238,11 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("got data = %v, want = %v", v, data) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + } + // Check that we get an ACK for the newly non-zero window. checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -1444,7 +1834,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1492,7 +1882,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1525,7 +1915,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. second := c.GetPacket() @@ -1562,7 +1952,7 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, } @@ -1692,6 +2082,34 @@ func TestSendGreaterThanMTU(t *testing.T) { testBrokenUpWrite(t, c, maxPayload) } +func TestSetTTL(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, checker.TTL(wantTTL)) + }) + } +} + func TestActiveSendMSSLessThanMTU(t *testing.T) { const maxPayload = 100 c := context.New(t, 65535) @@ -1997,6 +2415,13 @@ loop: t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } } func TestSendOnResetConnection(t *testing.T) { @@ -2568,6 +2993,17 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { @@ -2592,6 +3028,9 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { @@ -2618,6 +3057,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { if got := stats.TCP.ChecksumErrors.Value(); got != want { t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } } func TestReceivedSegmentQueuing(t *testing.T) { @@ -2674,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -2681,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2724,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -2744,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -2755,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -2765,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -2970,6 +3418,62 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ @@ -3880,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4014,6 +4519,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + } we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4052,6 +4560,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + // Expect InvalidEndpointState errors on a read at this point. + if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { + t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + } + if err := ep.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -4083,6 +4599,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -4370,3 +4889,590 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, 0) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption int + }{ + {delayEnabled: false, wantDelayOption: 0}, + {delayEnabled: true, wantDelayOption: 1}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + } + gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d3f1d2cdf..4854e719d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + if testing.Verbose() { + wep2 = sniffer.New(channel.New(1000, mtu, "")) + } + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -295,7 +302,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -343,13 +352,17 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) } // SendAck sends an ACK packet. @@ -511,7 +524,9 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. @@ -588,12 +603,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -606,6 +617,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } |