summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/BUILD9
-rw-r--r--pkg/tcpip/transport/tcp/accept.go54
-rw-r--r--pkg/tcpip/transport/tcp/connect.go458
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go56
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go593
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go59
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go5
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go67
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go167
-rw-r--r--pkg/tcpip/transport/tcp/segment.go15
-rw-r--r--pkg/tcpip/transport/tcp/snd.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go1158
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go42
15 files changed, 2404 insertions, 299 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 39a839ab7..3f47b328d 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "tcp_segment_list",
out = "tcp_segment_list.go",
@@ -45,10 +44,12 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
@@ -70,7 +71,7 @@ filegroup(
go_test(
name = "tcp_test",
- size = "small",
+ size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 0802e984e..f24b51b91 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -229,7 +230,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
- n.id = s.id
+ n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
@@ -242,7 +243,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -268,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
- ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
if err != nil {
return nil, err
}
@@ -288,9 +289,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
- h.resetToSynRcvd(cookie, irs, opts)
+ h.resetToSynRcvd(isn, irs, opts)
if err := h.execute(); err != nil {
- ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -298,7 +298,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
ep.mu.Lock()
+ ep.stack.Stats().TCP.CurrentEstablished.Increment()
ep.state = StateEstablished
+ ep.isConnectNotified = true
ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
@@ -311,14 +313,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- l.pendingEndpoints[n.id] = n
+ l.pendingEndpoints[n.ID] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.id)
+ delete(l.pendingEndpoints, n.ID)
l.pending.Done()
l.pendingMu.Unlock()
}
@@ -360,12 +362,19 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
+
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
e.deliverAccepted(n)
}
@@ -399,6 +408,17 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+ // must be sent in response to a SYN-ACK while in the listen
+ // state to prevent completing a handshake from an old SYN.
+ e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
+ return
+ }
+
+ // TODO(b/143300739): Use the userMSS of the listening socket
+ // for accepted sockets.
+
switch s.flags {
case header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
@@ -414,6 +434,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
decSynRcvdCount()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
} else {
@@ -421,6 +442,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// is full then drop the syn.
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
@@ -431,15 +453,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
- mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: uint16(mss),
+ MSS: mssForRoute(&s.route),
}
- sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
@@ -451,6 +472,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// complete the connection at the time of retransmit if
// the backlog has space.
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
@@ -505,6 +527,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
@@ -515,7 +538,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
+ n.stack.Stats().TCP.CurrentEstablished.Increment()
n.state = StateEstablished
+ n.isConnectNotified = true
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -526,6 +551,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// number of goroutines as we do check before
// entering here that there was at least some
// space available in the backlog.
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
}
@@ -536,7 +566,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 21038a65a..a114c06c1 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"sync"
"time"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -78,9 +80,6 @@ type handshake struct {
// mss is the maximum segment size received from the peer.
mss uint16
- // amss is the maximum segment size advertised by us to the peer.
- amss uint16
-
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
@@ -142,7 +141,32 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+ // The time period here is 64ns. This is similar to what linux uses
+ // generate a sequence number that overlaps less than one
+ // time per MSL (2 minutes).
+ //
+ // A 64ns clock ticks 10^9/64 = 15625000) times in a second.
+ // To wrap the whole 32 bit space would require
+ // 2^32/1562500 ~ 274 seconds.
+ //
+ // Which sort of guarantees that we won't reuse the ISN for a new
+ // connection for the same tuple for at least 274s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
}
// effectiveRcvWndScale returns the effective receive window scale to be used.
@@ -238,6 +262,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
h.ep.state = StateSynRecv
+ ttl := h.ep.ttl
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
@@ -251,8 +276,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
SACKPermitted: rcvSynOpts.SACKPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
-
+ if ttl == 0 {
+ ttl = s.route.DefaultTTL()
+ }
+ h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -296,7 +323,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -383,6 +410,11 @@ func (h *handshake) resolveRoute() *tcpip.Error {
switch index {
case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if err == tcpip.ErrNoLinkAddress {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ } else if err != nil {
+ h.ep.stats.SendErrors.NoRoute.Increment()
+ }
// Either success (err == nil) or failure.
return err
}
@@ -437,7 +469,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.amss = mssForRoute(&h.ep.route)
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -460,7 +492,8 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
for h.state != handshakeCompleted {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
@@ -469,7 +502,7 @@ func (h *handshake) execute() *tcpip.Error {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
n := h.ep.fetchNotifications()
@@ -579,24 +612,30 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
options := makeSynOptions(opts)
- err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ // We ignore SYN send errors and let the callers re-attempt send.
+ if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ e.stats.SendErrors.SynSendToNetworkFailed.Increment()
+ }
putOptions(options)
- return err
+ return nil
}
-// sendTCP sends a TCP segment with the provided options via the provided
-// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+ e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
+ return err
}
+ e.stats.SegmentsSent.Increment()
+ return nil
+}
+func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
+ optLen := len(opts)
+ hdr := &d.Hdr
+ packetSize := d.Size
+ off := d.Off
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
@@ -610,7 +649,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
@@ -620,16 +659,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVV(data, xsum)
+ xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+}
+
+func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen)
+
+ size := data.Size()
+ off := 0
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ hdrs[i].Off = off
+ hdrs[i].Size = packetSize
+ buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso)
+ off += packetSize
+ seq = seq.Add(seqnum.Size(packetSize))
+ }
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ }
+
+ d := &stack.PacketDescriptor{
+ Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Off: 0,
+ Size: data.Size(),
+ }
+ buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso)
+
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ r.Stats().TCP.SegmentSendErrors.Increment()
+ return err
+ }
r.Stats().TCP.SegmentsSent.Increment()
if (flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
-
- return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+ return nil
}
// makeOptions makes an options slice.
@@ -678,7 +780,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
putOptions(options)
return err
}
@@ -727,10 +829,26 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
+ if e.state == StateEstablished || e.state == StateCloseWait {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
e.state = StateError
- e.hardError = err
+ e.HardError = err
if err != tcpip.ErrConnectionReset {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
}
}
@@ -744,6 +862,51 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ e.mu.Lock()
+ switch e.state {
+ // In case of a RST in CLOSE-WAIT linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.state = StateClose
+ e.HardError = tcpip.ErrAborted
+ // We need to set this explicitly here because otherwise
+ // the port registrations will not be released till the
+ // endpoint is actively closed by the application.
+ e.workerCleanup = true
+ e.mu.Unlock()
+ return false, nil
+ default:
+ e.mu.Unlock()
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
// handleSegments pulls segments from the queue and processes them. It returns
// no error if the protocol loop should continue, an error otherwise.
func (e *endpoint) handleSegments() *tcpip.Error {
@@ -761,14 +924,34 @@ func (e *endpoint) handleSegments() *tcpip.Error {
}
if s.flagIsSet(header.TCPFlagRst) {
- if e.rcv.acceptable(s.sequenceNumber, 0) {
- // RFC 793, page 37 states that "in all states
- // except SYN-SENT, all reset (RST) segments are
- // validated by checking their SEQ-fields." So
- // we only process it if it's acceptable.
- s.decRef()
- return tcpip.ErrConnectionReset
+ if ok, err := e.handleReset(s); !ok {
+ return err
}
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
@@ -777,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
// information."
- e.rcv.handleRcvdSegment(s)
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if drop {
+ s.decRef()
+ continue
+ }
e.snd.handleRcvdSegment(s)
}
s.decRef()
@@ -876,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
-
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -897,8 +1087,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
e.state = StateError
- e.hardError = err
+ e.HardError = err
// Lock released below.
epilogue()
@@ -920,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// RTT itself.
e.rcvAutoParams.prevCopied = initialRcvWnd
e.rcvListMu.Unlock()
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.mu.Lock()
+ e.state = StateEstablished
+ e.mu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
@@ -927,7 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- e.state = StateEstablished
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -958,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
{
w: &closeWaker,
f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ // This means the socket is being closed due
+ // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.mu.Lock()
+ e.state = StateClose
+ e.mu.Unlock()
+ return nil
},
},
{
@@ -1001,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
}
+
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after
- // the endpoint has been closed.
- //
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK as the loop here will not honor
- // the firing until the undrain arrives.
- closeTimer = time.AfterFunc(3*time.Second, func() {
- closeWaker.Assert()
- })
+ e.mu.Lock()
+ if e.state == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
+ closeWaker.Assert()
+ })
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1033,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
}
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
return nil
},
},
@@ -1059,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.mu.RLock()
+ e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
- e.mu.RUnlock()
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+
+ for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+ e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
@@ -1083,15 +1292,156 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return nil
}
+ e.mu.Lock()
+ }
+
+ state := e.state
+ e.mu.Unlock()
+ var reuseTW func()
+ if state == StateTimeWait {
+ // Disable close timer as we now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
if e.state != StateError {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
e.state = StateClose
}
+
// Lock released below.
epilogue()
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue
+ if reuseTW != nil {
+ reuseTW()
+ }
+
return nil
}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ tcpEP.enqueueSegment(s)
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than one see on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ <-e.undrain
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index c54610a87..dfaa4a559 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
}
}
-func testV4Connect(t *testing.T, c *context.Context) {
+func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv4(t, b, synCheckers...)
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv4(t, c.GetPacket(), ackCheckers...)
// Wait for connection to be established.
select {
@@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
testV4Connect(t, c)
}
-func testV6Connect(t *testing.T, c *context.Context) {
+func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt to IPv6 address.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetV6Packet()
- checker.IPv6(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv6(t, b, synCheckers...)
tcp := header.TCP(header.IPv6(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
// Wait for connection to be established.
select {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 35b489c68..04c92c04c 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -119,6 +121,11 @@ const (
notifyReset
notifyKeepaliveChanged
notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is
+ // say TIME_WAIT.
+ notifyTickleWorker
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -170,6 +177,101 @@ type rcvBufAutoTuneParams struct {
disabled bool
}
+// ReceiveErrors collect segment receive errors within transport layer.
+type ReceiveErrors struct {
+ tcpip.ReceiveErrors
+
+ // SegmentQueueDropped is the number of segments dropped due to
+ // a full segment queue.
+ SegmentQueueDropped tcpip.StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors tcpip.StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop tcpip.StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop tcpip.StatCounter
+
+ // ZeroRcvWindowState is the number of times we advertised
+ // a zero receive window when rcvList is full.
+ ZeroRcvWindowState tcpip.StatCounter
+}
+
+// SendErrors collect segment send errors within the transport layer.
+type SendErrors struct {
+ tcpip.SendErrors
+
+ // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent
+ // to the network endpoint.
+ SegmentSendToNetworkFailed tcpip.StatCounter
+
+ // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent
+ // to the network endpoint.
+ SynSendToNetworkFailed tcpip.StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits tcpip.StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit tcpip.StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts tcpip.StatCounter
+}
+
+// Stats holds statistics about the endpoint.
+type Stats struct {
+ // SegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ SegmentsReceived tcpip.StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent tcpip.StatCounter
+
+ // FailedConnectionAttempts is the number of times we saw Connect and
+ // Accept errors.
+ FailedConnectionAttempts tcpip.StatCounter
+
+ // ReceiveErrors collects segment receive errors within the
+ // transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects segment read errors from an endpoint read call.
+ ReadErrors tcpip.ReadErrors
+
+ // SendErrors collects segment send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects segment write errors from an endpoint write call.
+ WriteErrors tcpip.WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*Stats) IsEndpointStats() {}
+
+// EndpointInfo holds useful information about a transport endpoint which
+// can be queried by monitoring tools.
+//
+// +stateify savable
+type EndpointInfo struct {
+ stack.TransportEndpointInfo
+
+ // HardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. HardError is protected by endpoint mu.
+ HardError *tcpip.Error `state:".(string)"`
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*EndpointInfo) IsEndpointInfo() {}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -178,6 +280,8 @@ type rcvBufAutoTuneParams struct {
//
// +stateify savable
type endpoint struct {
+ EndpointInfo
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -186,9 +290,9 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
- stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
+ stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -218,14 +322,19 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
- id stack.TransportEndpointID
state EndpointState `state:".(EndpointState)"`
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to it's correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
route stack.Route `state:"manual"`
+ ttl uint8
v6only bool
isConnectNotified bool
// TCP should never broadcast but Linux nevertheless supports enabling/
@@ -240,11 +349,6 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
- // hardError is meaningful only when state is stateError, it stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. hardError is protected by mu.
- hardError *tcpip.Error `state:".(string)"`
-
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -280,6 +384,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -315,7 +422,7 @@ type endpoint struct {
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
- userMSS int
+ userMSS uint16
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -393,13 +500,49 @@ type endpoint struct {
probe stack.TCPProbeFunc `state:"nosave"`
// The following are only used to assist the restore run to re-connect.
- bindAddress tcpip.Address
connectingAddress tcpip.Address
// amss is the advertised MSS to the peer by this endpoint.
amss uint16
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
gso *stack.GSO
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of a time a socket
+ // a socket stays in TIME_WAIT state before being marked
+ // closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called closed on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
}
// StopWork halts packet processing. Only to be used in tests.
@@ -427,10 +570,15 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ EndpointInfo: EndpointInfo{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ },
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
@@ -443,29 +591,40 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
}
var ss SendBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
var rs ReceiveBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
var cs tcpip.CongestionControlOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
var mrb tcpip.ModerateReceiveBufferOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- if p := stack.GetTCPProbe(); p != nil {
+ var de DelayEnabled
+ if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
+ e.SetSockOptInt(tcpip.DelayOption, 1)
+ }
+
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
+ if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -552,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
+ e.mu.Lock()
+ closed := e.closed
+ e.mu.Unlock()
+ if closed {
+ return
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -564,14 +730,16 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
+ // Mark endpoint as closed.
+ e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
@@ -597,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
go func() {
defer close(done)
for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
n.Close()
}
}()
@@ -625,16 +791,17 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -645,7 +812,9 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
- routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+
+ // Use the user supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
@@ -731,11 +900,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.hardError
+ he := e.HardError
e.mu.RUnlock()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -744,6 +914,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
e.mu.RUnlock()
+ if err == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
return v, tcpip.ControlMessages{}, err
}
@@ -787,7 +960,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, e.hardError
+ return 0, e.HardError
default:
return 0, tcpip.ErrClosedForSend
}
@@ -818,6 +991,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if err != nil {
e.sndBufMu.Unlock()
e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -852,6 +1026,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if err != nil {
e.sndBufMu.Unlock()
e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -863,7 +1038,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Add data to the send queue.
- s := newSegmentFromView(&e.route, e.id, v)
+ s := newSegmentFromView(&e.route, e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
@@ -895,8 +1070,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.state; !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.hardError
+ return 0, tcpip.ControlMessages{}, e.HardError
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -905,6 +1081,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.state.connected() {
+ e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -1018,14 +1195,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
- default:
- return nil
- }
-}
-
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
case tcpip.DelayOption:
if v == 0 {
atomic.StoreUint32(&e.delay, 0)
@@ -1037,6 +1206,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+ switch v := opt.(type) {
case tcpip.CorkOption:
if v == 0 {
atomic.StoreUint32(&e.cork, 0)
@@ -1060,6 +1239,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicID, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicID
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.QuickAckOption:
if v == 0 {
atomic.StoreUint32(&e.slowAck, 1)
@@ -1074,14 +1268,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
e.mu.Lock()
- e.userMSS = int(userMSS)
+ e.userMSS = uint16(userMSS)
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -1096,6 +1290,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.v6only = v != 0
return nil
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.KeepaliveEnabledOption:
e.keepalive.Lock()
e.keepalive.enabled = v != 0
@@ -1164,6 +1364,45 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Linux returns ENOENT when an invalid congestion
// control algorithm is specified.
return tcpip.ErrNoSuchFile
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ if v < 0 {
+ // Same as effectively disabling TCPLinger timeout.
+ v = 0
+ }
+ var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
+ // We were unable to retrieve a stack config, just use
+ // the DefaultTCPLingerTimeout.
+ if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
+ stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ }
+ }
+ // Cap it to the stack wide TCPLinger timeout.
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1190,6 +1429,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+
case tcpip.SendBufferSizeOption:
e.sndBufMu.Lock()
v := e.sndBufSize
@@ -1202,8 +1442,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
+ case tcpip.DelayOption:
+ var o int
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ o = 1
+ }
+ return o, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -1224,13 +1472,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.DelayOption:
- *o = 0
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- *o = 1
- }
- return nil
-
case *tcpip.CorkOption:
*o = 0
if v := atomic.LoadUint32(&e.cork); v != 0 {
@@ -1260,6 +1501,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1269,7 +1520,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -1283,6 +1534,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.mu.RLock()
@@ -1347,13 +1604,31 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
@@ -1369,7 +1644,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1383,7 +1658,12 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- return e.connect(addr, true, true)
+ err := e.connect(addr, true, true)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
}
// connect connects the endpoint to its peer. In the normal non-S/R case, the
@@ -1392,14 +1672,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// created (so no new handshaking is done); for stack-accepted connections not
// yet accepted by the app, they are restored without running the main goroutine
// here.
-func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
connectingAddr := addr.Addr
@@ -1419,7 +1694,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
return tcpip.ErrAlreadyConnected
}
- nicid := addr.NIC
+ nicID := addr.NIC
switch e.state {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
@@ -1428,11 +1703,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
break
}
- if nicid != 0 && nicid != e.boundNICID {
+ if nicID != 0 && nicID != e.boundNICID {
return tcpip.ErrNoRoute
}
- nicid = e.boundNICID
+ nicID = e.boundNICID
case StateInitial:
// Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
@@ -1444,29 +1719,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.hardError
+ return e.HardError
default:
return tcpip.ErrInvalidEndpointState
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
- origID := e.id
+ origID := e.ID
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- e.id.LocalAddress = r.LocalAddress
- e.id.RemoteAddress = r.RemoteAddress
- e.id.RemotePort = addr.Port
+ e.ID.LocalAddress = r.LocalAddress
+ e.ID.RemoteAddress = r.RemoteAddress
+ e.ID.RemotePort = addr.Port
- if e.id.LocalPort != 0 {
+ if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1475,20 +1750,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// one. Make sure that it isn't one that will result in the same
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
- sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- if sameAddr && p == e.id.RemotePort {
+ sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+
+ // Calculate a port offset based on the destination IP/port and
+ // src IP to ensure that for a given tuple (srcIP, destIP,
+ // destPort) the offset used as a starting point is the same to
+ // ensure that we can cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.Seed())
+ h.Write([]byte(e.ID.LocalAddress))
+ h.Write([]byte(e.ID.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
- id := e.id
+ id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
- e.id = id
+ e.ID = id
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -1504,14 +1794,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
e.isRegistered = true
e.state = StateConnecting
e.route = r.Clone()
- e.boundNICID = nicid
+ e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -1523,7 +1813,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
e.segmentQueue.mu.Lock()
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
- s.id = e.id
+ s.id = e.ID
s.route = r.Clone()
e.sndWaker.Assert()
}
@@ -1531,6 +1821,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.state = StateEstablished
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
}
if run {
@@ -1551,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
- defer e.mu.Unlock()
e.shutdownFlags |= flags
-
+ finQueued := false
switch {
case e.state.connected():
// Close for read.
@@ -1568,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
e.notifyProtocolGoroutine(notifyReset)
+ e.mu.Unlock()
return nil
}
}
@@ -1583,17 +1874,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.id, nil)
+ s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
-
+ finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
-
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
}
case e.state == StateListen:
@@ -1601,24 +1889,37 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
-
default:
+ e.mu.Unlock()
return tcpip.ErrNotConnected
}
-
+ e.mu.Unlock()
+ if finQueued {
+ if e.workMu.TryLock() {
+ e.handleClose()
+ e.workMu.Unlock()
+ } else {
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+ }
return nil
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ err := e.listen(backlog)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+func (e *endpoint) listen(backlog int) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -1644,11 +1945,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1692,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrWouldBlock
}
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- return n, wq, nil
+ return n, n.waiterQueue, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
@@ -1712,7 +2009,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
return tcpip.ErrAlreadyBound
}
- e.bindAddress = addr.Addr
+ e.BindAddr = addr.Addr
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -1729,26 +2026,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
e.isPortReserved = true
e.effectiveNetProtos = netProtos
- e.id.LocalPort = port
+ e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
- e.id.LocalPort = 0
- e.id.LocalAddress = ""
+ e.ID.LocalPort = 0
+ e.ID.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -1759,7 +2056,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.boundNICID = nic
- e.id.LocalAddress = addr.Addr
+ e.ID.LocalAddress = addr.Addr
}
// Mark endpoint as bound.
@@ -1774,8 +2071,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
NIC: e.boundNICID,
}, nil
}
@@ -1790,19 +2087,20 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
NIC: e.boundNICID,
}, nil
}
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
- s := newSegment(r, id, vv)
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ s := newSegment(r, id, pkt)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
s.decRef()
return
}
@@ -1810,27 +2108,34 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if !s.csumValid {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
s.decRef()
return
}
e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ e.stats.SegmentsReceived.Increment()
if (s.flags & header.TCPFlagRst) != 0 {
e.stack.Stats().TCP.ResetsReceived.Increment()
}
+ e.enqueueSegment(s)
+}
+
+func (e *endpoint) enqueueSegment(s *segment) {
// Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) {
e.newSegmentWaker.Assert()
} else {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
s.decRef()
}
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -1874,6 +2179,7 @@ func (e *endpoint) readyToRead(s *segment) {
// that a subsequent read of the segment will correctly trigger
// a non-zero notification.
if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
e.zeroWindow = true
}
e.rcvList.PushBack(s)
@@ -2026,7 +2332,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Copy EndpointID.
e.mu.Lock()
- s.ID = stack.TCPEndpointID(e.id)
+ s.ID = stack.TCPEndpointID(e.ID)
e.mu.Unlock()
// Copy endpoint rcv state.
@@ -2119,11 +2425,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
return s
}
-func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityGSO == 0 {
- return
- }
-
+func (e *endpoint) initHardwareGSO() {
gso := &stack.GSO{}
switch e.route.NetProto {
case header.IPv4ProtocolNumber:
@@ -2133,7 +2435,7 @@ func (e *endpoint) initGSO() {
gso.Type = stack.GSOTCPv6
gso.L3HdrLen = header.IPv6MinimumSize
default:
- panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto))
}
gso.NeedsCsum = true
gso.CsumOffset = header.TCPChecksumOffset
@@ -2141,6 +2443,18 @@ func (e *endpoint) initGSO() {
e.gso = gso
}
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
@@ -2149,6 +2463,37 @@ func (e *endpoint) State() uint32 {
return uint32(e.state)
}
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.EndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.mu.Lock()
+ running := e.workerRunning
+ e.mu.Unlock()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
+}
+
func mssForRoute(r *stack.Route) uint16 {
+ // TODO(b/143359391): Respect TCP Min and Max size.
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 831389ec7..7aa4c3f0e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() {
case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
@@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() {
}
fallthrough
case StateError, StateClose:
- for e.state == StateError && e.workerRunning {
+ for (e.state == StateError || e.state == StateClose) && e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
@@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ // Freeze segment queue before registering to prevent any segments
+ // from being delivered while it is being restored.
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
+ state := e.origEndpointState
- state := e.state
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
bind := func() {
- e.state = StateInitial
- if len(e.bindAddress) == 0 {
- e.bindAddress = e.id.LocalAddress
+ if len(e.BindAddr) == 0 {
+ e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
+ addr := e.BindAddr
+ port := e.ID.LocalPort
+ if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
+ panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
}
}
@@ -202,21 +209,31 @@ func (e *endpoint) Resume(s *stack.Stack) {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.id.RemoteAddress
+ e.connectingAddress = e.ID.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if state == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
connectedLoading.Done()
case StateListen:
tcpip.AsyncLoading.Add(1)
@@ -236,7 +253,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
connectingLoading.Done()
@@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
}
- fallthrough
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
case StateError:
+ e.state = StateError
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
@@ -288,21 +309,21 @@ func (e *endpoint) loadLastError(s string) {
}
// saveHardError is invoked by stateify.
-func (e *endpoint) saveHardError() string {
- if e.hardError == nil {
+func (e *EndpointInfo) saveHardError() string {
+ if e.HardError == nil {
return ""
}
- return e.hardError.String()
+ return e.HardError.String()
}
// loadHardError is invoked by stateify.
-func (e *endpoint) loadHardError(s string) {
+func (e *EndpointInfo) loadHardError(s string) {
if s == "" {
return
}
- e.hardError = loadError(s)
+ e.HardError = loadError(s)
}
var messageToError map[string]*tcpip.Error
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 63666f0b3..4983bca81 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index d5d8ab96a..89b965c23 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,6 +23,7 @@ package tcp
import (
"strings"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -54,12 +55,23 @@ const (
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
)
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
+// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+type DelayEnabled bool
+
// SendBufferSizeOption allows the default, min and max send buffer sizes for
// TCP endpoints to be queried or configured.
type SendBufferSizeOption struct {
@@ -84,11 +96,14 @@ const (
type protocol struct {
mu sync.Mutex
sackEnabled bool
+ delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
+ tcpLingerTimeout time.Duration
+ tcpTimeWaitTimeout time.Duration
}
// Number returns the tcp protocol number.
@@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
@@ -153,7 +168,7 @@ func replyWithReset(s *segment) {
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
@@ -165,6 +180,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case DelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
case SendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
@@ -202,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpLingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpTimeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -216,6 +255,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *DelayEnabled:
+ p.mu.Lock()
+ *v = DelayEnabled(p.delayEnabled)
+ p.mu.Unlock()
+ return nil
+
case *SendBufferSizeOption:
p.mu.Lock()
*v = p.sendBufferSize
@@ -246,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -258,5 +315,7 @@ func NewProtocol() stack.TransportProtocol {
recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
+ tcpLingerTimeout: DefaultTCPLingerTimeout,
+ tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index e90f9a7d9..068b90fb6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -18,6 +18,7 @@ import (
"container/heap"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
switch r.ep.state {
case StateFinWait1:
r.ep.state = StateFinWait2
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
r.ep.state = StateTimeWait
case StateLastAck:
@@ -253,23 +259,105 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-// handleRcvdSegment handles TCP segments directed at the connection managed by
-// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait, StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ s.decRef()
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+ // If we are closed for reads (either due to an
+ // incoming FIN or the user calling shutdown(..,
+ // SHUT_RD) then any data past the rcvNxt should
+ // trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+ // In FIN-WAIT2 if the socket is fully
+ // closed(not owned by application on our end
+ // then the only acceptable segment is a
+ // FIN. Since FIN can technically also carry
+ // data we verify that the segment carrying a
+ // FIN ends at exactly e.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
// We don't care about receive processing anymore if the receive side
// is closed.
- if r.closed {
- return
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
+ }
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ r.ep.mu.RLock()
+ state := r.ep.state
+ closed := r.ep.closed
+ r.ep.mu.RUnlock()
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
- // send an ACK. This is according to RFC 793, page 37.
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
- return
+ return true, nil
}
// Defer segment processing if it can't be consumed now.
@@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
// have to retransmit.
r.ep.snd.sendAck()
}
- return
+ return false, nil
}
// Since we consumed a segment update the receiver's RTT estimate
@@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) {
r.pendingBufUsed -= s.logicalLen()
s.decRef()
}
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+ // TIME_WAIT assasination as a result we confirm w/ fix 1 as described
+ // in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+ // for this connection then try and redirect it to a listening endpoint
+ // if available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+ if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+
+ return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+ // indicates our final ACK could have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index ea725d513..1c10da5ca 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -60,13 +61,13 @@ type segment struct {
xmitTime time.Time `state:".(unixTime)"`
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
- s.data = vv.Clone(s.views[:])
+ s.data = pkt.Data.Clone(s.views[:])
s.rcvdTime = time.Now()
return s
}
@@ -99,8 +100,14 @@ func (s *segment) clone() *segment {
return t
}
-func (s *segment) flagIsSet(flag uint8) bool {
- return (s.flags & flag) != 0
+// flagIsSet checks if at least one flag in flags is set in s.flags.
+func (s *segment) flagIsSet(flags uint8) bool {
+ return s.flags&flags != 0
+}
+
+// flagsAreSet checks if all flags in flags are set in s.flags.
+func (s *segment) flagsAreSet(flags uint8) bool {
+ return s.flags&flags == flags
}
func (s *segment) decRef() {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 735edfe55..d3f7c9125 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -417,6 +417,7 @@ func (s *sender) resendSegment() {
s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+ s.ep.stats.SendErrors.FastRetransmit.Increment()
// Run SetPipe() as per RFC 6675 section 5 Step 4.4
s.SetPipe()
@@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool {
}
s.ep.stack.Stats().TCP.Timeouts.Increment()
+ s.ep.stats.SendErrors.Timeouts.Increment()
// Give up if we've waited more than a minute since the last resend.
if s.rto >= 60*time.Second {
@@ -672,6 +674,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
default:
s.ep.state = StateFinWait1
}
+ s.ep.stack.Stats().TCP.CurrentEstablished.Decrement()
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
@@ -1188,6 +1191,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
if !seg.xmitTime.IsZero() {
s.ep.stack.Stats().TCP.Retransmits.Increment()
+ s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 9fa97528b..782d7b42c 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 4e7f1a740..afea124ec 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) {
t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ }
+
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
// Acknowledge all pending data to recover point.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 2be094876..84579ce52 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -99,7 +99,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
}
}
@@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ }
}
func TestTCPSegmentsSentIncrement(t *testing.T) {
@@ -136,6 +142,9 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
+ t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ }
}
func TestTCPResetsSentIncrement(t *testing.T) {
@@ -197,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -247,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -255,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
}
+ // Lower stackwide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
@@ -276,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Get the ACK to the FIN we just sent.
c.GetPacket()
+ // Since an active close was done we need to wait for a little more than
+ // tcpLingerTimeout for the port reservations to be released and the
+ // socket to move to a CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
// Now resend the same ACK, this ACK should generate a RST as there
// should be no endpoint in SYN-RCVD state and we are not using
// syn-cookies yet. The reason we send the same ACK is we need a valid
@@ -367,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLinger to 3 seconds so that sockets are marked closed
+ // after 3 second in FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -387,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1),
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
- // Wait for the ep to give up waiting for a FIN, and send a RST.
- time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
tcpHdr := header.TCP(header.IPv4(b).Payload())
@@ -404,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+2),
checker.AckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
),
@@ -465,6 +506,347 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv4 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
+// creating a new active IPv6 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv6 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestTOSV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+ t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ }
+
+ var v tcpip.IPv4TOSOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv4TOSOption(tos); v != want {
+ t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testV4Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestTrafficClassV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+ t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+ }
+
+ var v tcpip.IPv6TrafficClassOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+ t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetV6Packet()
+ checker.IPv6(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -760,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -797,6 +1178,10 @@ func TestShutdownRead(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ }
}
func TestFullWindowReceive(t *testing.T) {
@@ -853,6 +1238,11 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("got data = %v, want = %v", v, data)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ }
+
// Check that we get an ACK for the newly non-zero window.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -1444,7 +1834,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -1492,7 +1882,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -1525,7 +1915,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOpt(tcpip.DelayOption(0))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 0)
// Check that data is received.
second := c.GetPacket()
@@ -1562,7 +1952,7 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
{"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
}
@@ -1692,6 +2082,34 @@ func TestSendGreaterThanMTU(t *testing.T) {
testBrokenUpWrite(t, c, maxPayload)
}
+func TestSetTTL(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+
+ checker.IPv4(t, b, checker.TTL(wantTTL))
+ })
+ }
+}
+
func TestActiveSendMSSLessThanMTU(t *testing.T) {
const maxPayload = 100
c := context.New(t, 65535)
@@ -1997,6 +2415,13 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ }
+ if tcp.EndpointState(c.EP.State()) != tcp.StateError {
+ t.Fatalf("got EP state is not StateError")
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -2568,6 +2993,17 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
+ t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ }
+ // Ensure there were no errors during handshake. If these stats have
+ // incremented, then the connection should not have been established.
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ }
}
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
@@ -2592,6 +3028,9 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
}
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
@@ -2618,6 +3057,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
if got := stats.TCP.ChecksumErrors.Value(); got != want {
t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
+ }
}
func TestReceivedSegmentQueuing(t *testing.T) {
@@ -2674,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
@@ -2681,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2724,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
- // Give the stack the chance to transition to closed state. Note that since
- // both the sender and receiver are now closed, we effectively skip the
- // TIME-WAIT state.
- time.Sleep(1 * time.Second)
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -2744,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) {
peekBuf := make([]byte, 10)
n, _, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
- t.Fatalf("Peek failed: %v", err)
+ t.Fatalf("Peek failed: %s", err)
}
peekBuf = peekBuf[:n]
@@ -2755,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -2765,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -2970,6 +3418,62 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
@@ -3880,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
+ // pick a different src port for new SYN.
+ SrcPort: context.TestPort + 1,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
@@ -4014,6 +4519,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ }
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -4052,6 +4560,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ // Expect InvalidEndpointState errors on a read at this point.
+ if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ }
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
+ t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ }
+
if err := ep.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -4083,6 +4599,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+ t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ }
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -4370,3 +4889,590 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
payloadSize *= 2
}
}
+
+func TestDelayEnabled(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ checkDelayOption(t, c, false, 0) // Delay is disabled by default.
+
+ for _, v := range []struct {
+ delayEnabled tcp.DelayEnabled
+ wantDelayOption int
+ }{
+ {delayEnabled: false, wantDelayOption: 0},
+ {delayEnabled: true, wantDelayOption: 1},
+ } {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ }
+ checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ }
+}
+
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
+ t.Helper()
+
+ var gotDelayEnabled tcp.DelayEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ }
+ if gotDelayEnabled != wantDelayEnabled {
+ t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
+ }
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
+ if err != nil {
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ }
+ gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
+ t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
+ }
+ if gotDelayOption != wantDelayOption {
+ t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
+ }
+}
+
+func TestTCPLingerTimeout(t *testing.T) {
+ c := context.New(t, 1500 /* mtu */)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ testCases := []struct {
+ name string
+ tcpLingerTimeout time.Duration
+ want time.Duration
+ }{
+ {"NegativeLingerTimeout", -123123, 0},
+ {"ZeroLingerTimeout", 0, 0},
+ {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
+ // Values > stack's TCPLingerTimeout are capped to the stack's
+ // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
+ {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
+ t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ }
+ var v tcpip.TCPLingerTimeoutOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ }
+ if got, want := time.Duration(v), tc.want; got != want {
+ t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestTCPTimeWaitRSTIgnored(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now send a RST and this should be ignored and not
+ // generate an ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitOutOfOrder(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitNewSyn(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Send a SYN request w/ sequence number lower than
+ // the highest sequence number sent. We just reuse
+ // the same number.
+ iss = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+ // Send a SYN request w/ sequence number higher than
+ // the highest sequence number sent.
+ iss = seqnum.Value(792)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b = c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ time.Sleep(2 * time.Second)
+
+ // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+ // by another 5 seconds and also send us a duplicate ACK as it should
+ // indicate that the final ACK was potentially lost.
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Sleep for 4 seconds so at this point we are 1 second past the
+ // original tcpLingerTimeout of 5 seconds.
+ time.Sleep(4 * time.Second)
+
+ // Send an ACK and it should not generate any packet as the socket
+ // should still be in TIME_WAIT for another another 5 seconds due
+ // to the duplicate FIN we sent earlier.
+ *ackHeaders = *finHeaders
+ ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+ ackHeaders.Flags = header.TCPFlagAck
+ c.SendPacket(nil, ackHeaders)
+
+ c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+ // Now sleep for another 2 seconds so that we are past the
+ // extended TIME_WAIT of 7 seconds (2 + 5).
+ time.Sleep(2 * time.Second)
+
+ // Resend the same ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Receive the RST that should be generated as there is no valid
+ // endpoint.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(uint32(ackHeaders.SeqNum)),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index d3f1d2cdf..4854e719d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ if testing.Verbose() {
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ }
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -295,7 +302,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
@@ -343,13 +352,17 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.Inject(ipv4.ProtocolNumber, s)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: s,
+ })
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegment(payload, h),
+ })
}
// SendAck sends an ACK packet.
@@ -511,7 +524,9 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// CreateConnected creates a connected TCP endpoint.
@@ -588,12 +603,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -606,6 +617,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}