summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/checker.go2
-rw-r--r--pkg/tcpip/header/eth.go6
-rw-r--r--pkg/tcpip/header/eth_test.go4
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go23
-rw-r--r--pkg/tcpip/link/tun/BUILD1
-rw-r--r--pkg/tcpip/link/tun/device.go5
-rw-r--r--pkg/tcpip/socketops.go11
-rw-r--r--pkg/tcpip/stack/packet_buffer.go14
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go26
-rw-r--r--pkg/tcpip/stack/stack.go16
-rw-r--r--pkg/tcpip/stack/tcp.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go30
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go44
-rw-r--r--pkg/tcpip/transport/tcp/connect.go71
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go92
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go27
-rw-r--r--pkg/tcpip/transport/tcp/snd.go9
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go9
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go370
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go81
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
27 files changed, 708 insertions, 165 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index e0dfe5813..2f34bf8dd 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -729,7 +729,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
return
}
l := int(opts[i+1])
- if i < 2 || i+l > limit {
+ if l < 2 || i+l > limit {
return
}
i += l
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 95ade0e5c..1f18213e5 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -49,9 +49,9 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
- // unspecifiedEthernetAddress is the unspecified ethernet address
+ // UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
- unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
@@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
return false
}
- if addr == unspecifiedEthernetAddress {
+ if addr == UnspecifiedEthernetAddress {
return false
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index bf9ccbf1a..adc04e855 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
@@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b427c6170..b9db273d0 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -42,6 +42,14 @@ type Endpoint struct {
nested.Endpoint
}
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ if l := e.Endpoint.LinkAddress(); len(l) != 0 {
+ return l
+ }
+ return header.UnspecifiedEthernetAddress
+}
+
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
// Capabilities implements stack.LinkEndpoint.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+ c := e.Endpoint.Capabilities()
+ if c&stack.CapabilityLoopback == 0 {
+ c |= stack.CapabilityResolutionRequired
+ }
+ return c
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- linkAddr := e.Endpoint.LinkAddress()
+ linkAddr := e.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
@@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 {
}
// ARPHardwareType implements stack.LinkEndpoint.
-func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone {
+ return a
+ }
return header.ARPHardwareEther
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 4758a99ad..c3e4c3455 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/refs",
"//pkg/refsvfs2",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index d23210503..fa2131c28 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -174,7 +173,7 @@ func (d *Device) Write(data []byte) (int64, error) {
return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
dataLen := int64(len(data))
@@ -249,7 +248,7 @@ func (d *Device) Read() ([]byte, error) {
for {
info, ok := endpoint.Read()
if !ok {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
v, ok := d.encodePkt(&info)
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 6bce3af04..34ac62444 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -57,6 +57,11 @@ type SocketOptionsHandler interface {
// OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
+
+ // WakeupWriters is invoked when the send buffer size for an endpoint is
+ // changed. The handler notifies the writers if the send buffer size is
+ // increased with setsockopt(2) for TCP endpoints.
+ WakeupWriters()
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -98,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// WakeupWriters implements SocketOptionsHandler.WakeupWriters.
+func (*DefaultSocketOptionsHandler) WakeupWriters() {}
+
// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
return v
@@ -626,6 +634,9 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
}
so.sendBufferSize.Store(sendBufferSize)
+ if notify {
+ so.handler.WakeupWriters()
+ }
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9192d8433..29c22bfd4 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -282,14 +282,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
return v
}
-// Clone makes a shallow copy of pk.
-//
-// Clone should be called in such cases so that no modifications is done to
-// underlying packet payload.
+// Clone makes a semi-deep copy of pk. The underlying packet payload is
+// shared. Hence, no modifications is done to underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- buf: pk.buf,
+ buf: pk.buf.Clone(),
reserved: pk.reserved,
pushed: pk.pushed,
consumed: pk.consumed,
@@ -321,14 +319,14 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
-// CloneToInbound makes a shallow copy of the packet buffer to be used as an
-// inbound packet.
+// CloneToInbound makes a semi-deep copy of the packet buffer (similar to
+// Clone) to be used as an inbound packet.
//
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
newPk := &PacketBuffer{
- buf: pk.buf,
+ buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index a8da34992..87b023445 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,6 +123,32 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
+func TestPacketBufferClone(t *testing.T) {
+ data := concatViews(makeView(20), makeView(30), makeView(40))
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
+ })
+
+ bytesToDelete := 30
+ originalSize := data.Size()
+
+ clonedPks := []*PacketBuffer{
+ pk.Clone(),
+ pk.CloneToInbound(),
+ }
+ pk.Data().DeleteFront(bytesToDelete)
+ if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
+ t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
+ }
+ for _, clonedPk := range clonedPks {
+ if got := clonedPk.Data().Size(); got != originalSize {
+ t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
+ }
+ }
+}
+
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c73890c4c..8e5c6edbf 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -119,8 +119,7 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // seed is a one-time random value initialized at stack startup
- // and is used to seed the TCP port picking on active connections
+ // seed is a one-time random value initialized at stack startup.
//
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
@@ -161,6 +160,10 @@ type Stack struct {
// This is required to prevent potential ACK loops.
// Setting this to 0 will disable all rate limiting.
tcpInvalidRateLimit time.Duration
+
+ // tsOffsetSecret is the secret key for generating timestamp offsets
+ // initialized at stack startup.
+ tsOffsetSecret uint32
}
// UniqueID is an abstract generator of unique identifiers.
@@ -384,6 +387,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
tcpInvalidRateLimit: defaultTCPInvalidRateLimit,
+ tsOffsetSecret: randomGenerator.Uint32(),
}
// Add specified network protocols.
@@ -1819,14 +1823,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
-// Seed returns a 32 bit value that can be used as a seed value for port
-// picking, ISN generation etc.
-//
-// NOTE: The seed is generated once during stack initialization only.
-func (s *Stack) Seed() uint32 {
- return s.seed
-}
-
// Rand returns a reference to a pseudo random generator that can be used
// to generate random numbers as required.
func (s *Stack) Rand() *rand.Rand {
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index 90a8ba6cf..93ea83cdc 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -386,6 +386,12 @@ type TCPSndBufState struct {
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
+
+ // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer
+ // is disabled.
+ //
+ // Must be accessed using atomic operations.
+ AutoTuneSndBufDisabled uint32
}
// TCPEndpointStateInner contains the members of TCPEndpointState used directly
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index dda57e225..824cf6526 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -479,7 +479,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: d.stack.Seed(),
+ seed: d.stack.seed,
}
}
if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b3d8951ff..55854ba59 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -321,28 +321,26 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
defer route.Release()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = owner
+
if e.ops.GetHeaderIncluded() {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
return 0, err
}
- } else {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- pkt.Owner = owner
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- return 0, err
- }
+ return int64(len(payloadBytes)), nil
}
+ if err := route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
+ return 0, err
+ }
return int64(len(payloadBytes)), nil
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 8436d2cf0..c3922bbe5 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -96,6 +96,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index aa413ad05..9560ed43c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 {
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
+ stack *stack.Stack
+ protocol *protocol
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
@@ -119,9 +120,10 @@ func timeStamp(clock tcpip.Clock) uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
+ protocol: protocol,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6Only: v6Only,
@@ -213,7 +215,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
return nil, err
}
- n := newEndpoint(l.stack, netProto, queue)
+ n := newEndpoint(l.stack, l.protocol, netProto, queue)
n.ops.SetV6Only(l.v6Only)
n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
@@ -247,7 +249,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret)
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -600,7 +602,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
+ TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr)),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
@@ -726,24 +728,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.isRegistered = true
- // clear the tsOffset for the newly created
- // endpoint as the Timestamp was already
- // randomly offset when the original SYN-ACK was
- // sent above.
- n.TSOffset = 0
+ // Reset the tsOffset for the newly created endpoint to the one
+ // that we have used in SYN-ACK in order to calculate RTT.
+ n.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr)
// Switch state to connected.
n.isConnectNotified = true
- n.transitionToStateEstablishedLocked(&handshake{
- ep: n,
- iss: iss,
- ackNum: irs + 1,
- rcvWnd: seqnum.Size(n.initialReceiveWindow()),
- sndWnd: s.window,
- rcvWndScale: e.rcvWndScaleForHandshake(),
- sndWndScale: rcvdSynOptions.WS,
- mss: rcvdSynOptions.MSS,
- })
+ h := &handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ sampleRTTWithTSOnly: true,
+ }
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be procesed by the newly established endpoint.
@@ -779,7 +781,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.ops.GetV6Only()
- ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+ ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 93ed161f9..f85775a48 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,6 +105,11 @@ type handshake struct {
// sendSYNOpts is the cached values for the SYN options to be sent.
sendSYNOpts header.TCPSynOptions
+
+ // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't
+ // tell; then RTT can only be sampled when the incoming segment has timestamp
+ // options enabled.
+ sampleRTTWithTSOnly bool
}
func (e *endpoint) newHandshake() *handshake {
@@ -117,6 +122,8 @@ func (e *endpoint) newHandshake() *handshake {
h.resetState()
// Store reference to handshake state in endpoint.
e.h = h
+ // By the time handshake is created, e.ID is already initialized.
+ e.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, e.ID.LocalAddress, e.ID.RemoteAddress)
return h
}
@@ -150,7 +157,7 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret)
}
// generateSecureISN generates a secure Initial Sequence number based on the
@@ -266,8 +273,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// and the handshake is completed.
if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
-
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -402,9 +408,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
+
h.state = handshakeCompleted
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be procesed by the newly established endpoint.
@@ -557,6 +564,10 @@ func (h *handshake) complete() tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, h.sendSYNOpts)
+ // If we have ever retransmitted the SYN-ACK or
+ // SYN segment, we should only measure RTT if
+ // TS option is present.
+ h.sampleRTTWithTSOnly = true
}
case wakerForNotification:
@@ -600,6 +611,38 @@ func (h *handshake) complete() tcpip.Error {
return nil
}
+// transitionToStateEstablisedLocked transitions the endpoint of the handshake
+// to an established state given the last segment received from peer. It also
+// initializes sender/receiver.
+func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ var rtt time.Duration
+ if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 {
+ rtt = time.Duration(h.ep.timestamp()-s.parsedOptions.TSEcr) * time.Millisecond
+ }
+ if !h.sampleRTTWithTSOnly && rtt == 0 {
+ rtt = h.ep.stack.Clock().NowMonotonic().Sub(h.startTime)
+ }
+
+ if rtt > 0 {
+ h.ep.snd.updateRTO(rtt)
+ }
+
+ h.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ h.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ h.ep.setEndpointState(StateEstablished)
+}
+
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
@@ -965,26 +1008,6 @@ func (e *endpoint) completeWorkerLocked() {
}
}
-// transitionToStateEstablisedLocked transitions a given endpoint
-// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver.
-func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
- e.rcvQueueInfo.rcvQueueMu.Unlock()
-
- e.setEndpointState(StateEstablished)
-}
-
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 044123185..4937d126f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"math"
- "math/rand"
"runtime"
"strings"
"sync/atomic"
@@ -378,6 +377,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
+ protocol *protocol `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
@@ -803,9 +803,10 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
+ stack: s,
+ protocol: protocol,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: header.TCPProtocolNumber,
@@ -874,7 +875,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset(e.stack.Rand())
+
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -1717,6 +1718,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
return rcvBufSz
}
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *endpoint) OnSetSendBufferSize(sz int64) int64 {
+ atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1)
+ return sz
+}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *endpoint) WakeupWriters() {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ sendBufferSize := e.getSendBufferSize()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.WritableEvents)
+ }
+}
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -2177,7 +2199,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h := jenkins.Sum32(e.stack.Seed())
+ h := jenkins.Sum32(e.protocol.portOffsetSecret)
for _, s := range [][]byte{
[]byte(e.ID.LocalAddress),
[]byte(e.ID.RemoteAddress),
@@ -2329,6 +2351,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.setEndpointState(StateEstablished)
+ // Set the new auto tuned send buffer size after entering
+ // established state.
+ e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */)
}
if run {
@@ -2763,13 +2788,20 @@ func (e *endpoint) updateSndBufferUsage(v int) {
e.sndQueueInfo.sndQueueMu.Lock()
notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
e.sndQueueInfo.SndBufUsed -= v
+
+ // Get the new send buffer size with auto tuning, but do not set it
+ // unless we decide to notify the writers.
+ newSndBufSz := e.computeTCPSendBufferSize()
+
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
+ // Set the new send buffer size calculated from auto tuning.
+ e.ops.SetSendBufferSize(newSndBufSz, false /* notify */)
e.waiterQueue.Notify(waiter.WritableEvents)
}
}
@@ -2896,17 +2928,22 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
// timeStampOffset returns a randomized timestamp offset to be used when sending
// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset(rng *rand.Rand) uint32 {
+func timeStampOffset(secret uint32, src, dst tcpip.Address) uint32 {
// Initialize a random tsOffset that will be added to the recentTS
// everytime the timestamp is sent when the Timestamp option is enabled.
//
// See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
// why this is required.
//
- // NOTE: This is not completely to spec as normally this should be
- // initialized in a manner analogous to how sequence numbers are
- // randomized per connection basis. But for now this is sufficient.
- return rng.Uint32()
+ // TODO(https://gvisor.dev/issues/6473): This is not really secure as
+ // it does not use the recommended algorithm linked above.
+ h := jenkins.Sum32(secret)
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = h.Write([]byte(src))
+ _, _ = h.Write([]byte(dst))
+ return h.Sum32()
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
@@ -3091,3 +3128,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti
Max: ss.Max,
}
}
+
+// computeTCPSendBufferSize implements auto tuning of send buffer size and
+// returns the new send buffer size.
+func (e *endpoint) computeTCPSendBufferSize() int64 {
+ curSndBufSz := int64(e.getSendBufferSize())
+
+ // Auto tuning is disabled when the user explicitly sets the send
+ // buffer size with SO_SNDBUF option.
+ if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 {
+ return curSndBufSz
+ }
+
+ const packetOverheadFactor = 2
+ curMSS := e.snd.MaxPayloadSize
+ numSeg := InitialCwnd
+ if numSeg < e.snd.SndCwnd {
+ numSeg = e.snd.SndCwnd
+ }
+
+ // SndCwnd indicates the number of segments that can be sent. This means
+ // that the sender can send upto #SndCwnd segments and the send buffer
+ // size should be set to SndCwnd*MSS to accommodate sending of all the
+ // segments.
+ newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor)
+ if newSndBufSz < curSndBufSz {
+ return curSndBufSz
+ }
+ if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz {
+ newSndBufSz = int64(ss.Max)
+ }
+
+ return newSndBufSz
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 952ccacdd..f2e8b3840 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), &snd.probeWaker)
}
e.stack = s
+ e.protocol = protocolFromStack(s)
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := EndpointState(e.origEndpointState)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2e709ed78..78745ea86 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
- listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
}
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 18b834243..00a083dbe 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -96,6 +96,11 @@ type protocol struct {
maxRetries uint32
synRetries uint8
dispatcher dispatcher
+
+ // The following secrets are initialized once and stay unchanged after.
+ seqnumSecret uint32
+ portOffsetSecret uint32
+ tsOffsetSecret uint32
}
// Number returns the tcp protocol number.
@@ -105,7 +110,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
// NewEndpoint creates a new tcp endpoint.
func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newEndpoint(p.stack, netProto, waiterQueue), nil
+ return newEndpoint(p.stack, p, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
@@ -292,22 +297,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
case *tcpip.TCPMinRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.minRTO = MinRTO
+ } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO {
+ p.minRTO = minRTO
} else {
- p.minRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.maxRTO = MaxRTO
+ } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO {
+ p.maxRTO = maxRTO
} else {
- p.maxRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRetriesOption:
@@ -479,7 +488,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
maxRTO: MaxRTO,
maxRetries: MaxRetries,
recovery: tcpip.TCPRACKLossDetection,
+ seqnumSecret: s.Rand().Uint32(),
+ portOffsetSecret: s.Rand().Uint32(),
+ tsOffsetSecret: s.Rand().Uint32(),
}
p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
+
+// protocolFromStack retrieves the tcp.protocol instance from stack s.
+func protocolFromStack(s *stack.Stack) *protocol {
+ return s.TransportProtocolInstance(ProtocolNumber).(*protocol)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 92a66f17e..a1f1c4e59 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -382,6 +382,9 @@ func (s *sender) updateRTO(rtt time.Duration) {
if s.RTO < s.minRTO {
s.RTO = s.minRTO
}
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
+ }
}
// resendSegment resends the first unacknowledged segment.
@@ -1415,9 +1418,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft -= datalen
}
- // Update the send buffer usage and notify potential waiters.
- s.ep.updateSndBufferUsage(int(acked))
-
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
@@ -1437,6 +1437,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 89e9fb886..c35db7c95 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,7 +33,6 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- latency = 5 * time.Millisecond
)
func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
@@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
if !enableRACK {
setStackTCPRecovery(t, c, 0)
}
- createConnectedWithSACKAndTS(c)
+ // The delay should be below initial RTO (1s) otherwise retransimission
+ // will start. Choose a relatively large value so that estimated RTT
+ // keeps high even after a few rounds of undelayed RTT samples.
+ c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
data := make([]byte, numPackets*maxPayload)
for i := range data {
@@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- // This delay is added to increase RTT as low RTT can cause TLP
- // before sending ACK.
- time.Sleep(latency)
}
return data
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 83e0653b9..6255355bb 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -35,13 +35,13 @@ import (
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
// option enabled if the stack in the context has SACK and TS enabled.
func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
@@ -108,7 +108,7 @@ func TestSackDisabledConnect(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
data := []byte{1, 2, 3}
@@ -170,7 +170,7 @@ func TestSackPermittedAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
@@ -244,7 +244,7 @@ func TestSackDisabledAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 031f01357..bf726e86a 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -2143,7 +2144,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
}
- c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
@@ -2535,7 +2536,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Do 3-way handshake.
// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -3532,6 +3533,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -3554,8 +3561,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
),
)
}
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := 1 * time.Second
select {
case <-notifyCh:
case <-time.After((2 << numRetries) * initRTO):
@@ -3590,9 +3595,13 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- opt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
+ maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
}
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3618,8 +3627,8 @@ func TestMaxRTO(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
- if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
}
}
}
@@ -3670,6 +3679,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
@@ -6304,7 +6317,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -6385,7 +6398,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// maximum buffer size defined above.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
// NOTE: The timestamp values in the sent packets are meaningless to the
// peer so we just increment the timestamp value by 1 every batch as we
@@ -6515,7 +6528,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
@@ -7430,6 +7443,11 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ initRTO := 1 * time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -7440,7 +7458,6 @@ func TestTCPUserTimeout(t *testing.T) {
// Ensure that on the next retransmit timer fire, the user timeout has
// expired.
- initRTO := 1 * time.Second
userTimeout := initRTO / 2
v := tcpip.TCPUserTimeoutOption(userTimeout)
if err := c.EP.SetSockOpt(&v); err != nil {
@@ -7954,6 +7971,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
+func TestHandshakeRTT(t *testing.T) {
+ type testCase struct {
+ connect bool
+ tsEnabled bool
+ useCookie bool
+ retrans bool
+ delay time.Duration
+ wantRTT time.Duration
+ }
+ var testCases []testCase
+ for _, connect := range []bool{false, true} {
+ for _, tsEnabled := range []bool{false, true} {
+ for _, useCookie := range []bool{false, true} {
+ for _, retrans := range []bool{false, true} {
+ if connect && useCookie {
+ continue
+ }
+ delay := 800 * time.Millisecond
+ if retrans {
+ delay = 1200 * time.Millisecond
+ }
+ wantRTT := delay
+ // If syncookie is enabled, sample RTT only when TS option is enabled.
+ if !retrans && useCookie && !tsEnabled {
+ wantRTT = 0
+ }
+ // If retransmitted, sample RTT only when TS option is enabled.
+ if retrans && !tsEnabled {
+ wantRTT = 0
+ }
+ testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
+ }
+ }
+ }
+ }
+ for _, tt := range testCases {
+ tt := tt
+ t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
+ t.Parallel()
+ c := context.New(t, defaultMTU)
+ if tt.useCookie {
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ }
+ synOpts := header.TCPSynOptions{}
+ if tt.tsEnabled {
+ synOpts.TS = true
+ synOpts.TSVal = 42
+ }
+ if tt.connect {
+ c.CreateConnectedWithOptions(synOpts, tt.delay)
+ } else {
+ synOpts.MSS = defaultIPv4MSS
+ synOpts.WS = -1
+ c.AcceptWithOptions(-1, synOpts, tt.delay)
+ }
+ var info tcpip.TCPInfoOption
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+ if got := time.Duration(info.RTT).Round(tt.wantRTT); got != tt.wantRTT {
+ t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
+ }
+ if info.RTTVar != 0 && tt.wantRTT == 0 {
+ t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
+ }
+ if info.RTTVar == 0 && tt.wantRTT != 0 {
+ t.Fatalf("got info.RTTVar=0, expect non zero")
+ }
+ })
+ }
+}
+
+func TestSetRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ for _, tt := range []struct {
+ name string
+ RTO time.Duration
+ minRTO time.Duration
+ maxRTO time.Duration
+ err tcpip.Error
+ }{
+ {
+ name: "invalid minRTO",
+ minRTO: maxRTO + time.Second,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "invalid maxRTO",
+ maxRTO: minRTO - time.Millisecond,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "valid minRTO",
+ minRTO: maxRTO - time.Second,
+ },
+ {
+ name: "valid maxRTO",
+ maxRTO: minRTO + time.Millisecond,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ var opt tcpip.SettableTransportProtocolOption
+ if tt.minRTO > 0 {
+ min := tcpip.TCPMinRTOOption(tt.minRTO)
+ opt = &min
+ }
+ if tt.maxRTO > 0 {
+ max := tcpip.TCPMaxRTOOption(tt.maxRTO)
+ opt = &max
+ }
+ err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ if got, want := err, tt.err; got != want {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
+ }
+ if tt.err == nil {
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ if tt.minRTO > 0 && tt.minRTO != minRTO {
+ t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
+ }
+ if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
+ t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
+ }
+ }
+ })
+ }
+}
+
+func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
+ t.Helper()
+ var minOpt tcpip.TCPMinRTOOption
+ var maxOpt tcpip.TCPMaxRTOOption
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
+ }
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
+ }
+ return time.Duration(minOpt), time.Duration(maxOpt)
+}
+
// generateRandomPayload generates a random byte slice of the specified length
// causing a fatal test failure if it is unable to do so.
func generateRandomPayload(t *testing.T, n int) []byte {
@@ -7964,3 +8126,185 @@ func generateRandomPayload(t *testing.T, n int) []byte {
}
return buf
}
+
+func TestSendBufferTuning(t *testing.T) {
+ const maxPayload = 536
+ const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+ const packetOverheadFactor = 2
+
+ testCases := []struct {
+ name string
+ autoTuningDisabled bool
+ }{
+ {"autoTuningDisabled", true},
+ {"autoTuningEnabled", false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Set the stack option for send buffer size.
+ const defaultSndBufSz = maxPayload * tcp.InitialCwnd
+ const maxSndBufSz = defaultSndBufSz * 10
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+ }
+
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+
+ oldSz := c.EP.SocketOptions().GetSendBufferSize()
+ if oldSz != defaultSndBufSz {
+ t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
+ }
+
+ if tc.autoTuningDisabled {
+ c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
+ }
+
+ data := make([]byte, maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ w, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&w, waiter.WritableEvents)
+ defer c.WQ.EventUnregister(&w)
+
+ bytesRead := 0
+ for {
+ // Packets will be sent till the send buffer
+ // size is reached.
+ var r bytes.Reader
+ r.Reset(data[bytesRead : bytesRead+maxPayload])
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+ break
+ }
+
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
+ bytesRead += maxPayload
+ data = append(data, data...)
+ }
+
+ // Send an ACK and wait for connection to become writable again.
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
+ select {
+ case <-ch:
+ if err := c.EP.LastError(); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+
+ outSz := int64(defaultSndBufSz)
+ if !tc.autoTuningDisabled {
+ // Calculate the new auto tuned send buffer.
+ var info tcpip.TCPInfoOption
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+ outSz = (int64(info.SndCwnd) * packetOverheadFactor * (maxPayload))
+ }
+
+ if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
+ t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
+ }
+ })
+ }
+}
+
+func TestTimestampSynCookies(t *testing.T) {
+ clock := faketime.NewManualClock()
+ c := context.NewWithOpts(t, context.Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: defaultMTU,
+ Clock: clock,
+ })
+ defer c.Cleanup()
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(42, 0, tcpOpts[2:])
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss,
+ TCPOpts: tcpOpts[:],
+ })
+ // Get the TSVal of SYN-ACK.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ initialTSVal := tcpHdr.ParsedOptions().TSVal
+
+ header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ TCPOpts: tcpOpts[:],
+ })
+ c.EP, _, err = ep.Accept(nil)
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ defer wq.EventUnregister(&we)
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ } else if err != nil {
+ t.Fatalf("failed to accept: %s", err)
+ }
+
+ const elapsed = 200 * time.Millisecond
+ clock.Advance(elapsed)
+ data := []byte{1, 2, 3}
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // The endpoint should have a correct TSOffset so that the received TSVal
+ // should match our expectation.
+ if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, initialTSVal+uint32(elapsed.Milliseconds()); got != want {
+ t.Fatalf("got TSVal = %d, want %d", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 1deb1fe4d..65925daa5 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -32,7 +32,7 @@ import (
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
@@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
@@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
}
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 96e4849d2..6e55a7a32 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -122,6 +122,9 @@ type Options struct {
// MTU indicates the maximum transmission unit on the link layer.
MTU uint32
+
+ // Clock that is used by Stack.
+ Clock tcpip.Clock
}
// Context provides an initialized Network stack and a link layer endpoint
@@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
stackOpts := stack.Options{
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: opts.Clock,
}
if opts.EnableV4 {
stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
@@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
)
}
+// CreateConnectedWithOptionsNoDelay just calls CreateConnectedWithOptions
+// without delay.
+func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
+ return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
+// the connection. It delays before a SYNACK is sent. This makes c.EP have a
+// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps
+// reduce flakiness.
//
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
@@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// TS value.
mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
+ synChecker := checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
)
+ checker.IPv4(c.t, b, synChecker)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(TestInitialSequenceNumber)
+ if delay > 0 {
+ // Sleep so that RTT is increased.
+ time.Sleep(delay)
+ }
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
})
// Read ACK.
- ackPacket := c.GetPacket()
+ var ackPacket []byte
+ // Ignore retransimitted SYN packets.
+ for {
+ packet := c.GetPacket()
+ if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
+ checker.IPv4(c.t, packet, synChecker)
+ } else {
+ ackPacket = packet
+ break
+ }
+ }
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
@@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}
}
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
+// AcceptWithOptionsNoDelay delegates call to AcceptWithOptions without delay.
+func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with
+// the provided options enabled. It delays before the final ACK of the 3WHS is
+// sent. It also verifies that the SYN-ACK has the expected values for the
+// provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
+// the enabled options. The final ACK of the handshake is delayed by specified
+// duration.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
@@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
@@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
ackHeaders.TCPOpts = opts[:]
}
- // Send ACK.
+ // Send ACK, delay if needed.
+ if delay > 0 {
+ time.Sleep(delay)
+ }
c.SendPacket(nil, ackHeaders)
c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 82a3f2287..108580508 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -266,7 +266,7 @@ func (e *endpoint) Close() {
for mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
- e.multicastMemberships = make(map[multicastMembership]struct{})
+ e.multicastMemberships = nil
// Close the receive list and drain it.
e.rcvMu.Lock()