diff options
Diffstat (limited to 'pkg/tcpip')
27 files changed, 708 insertions, 165 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index e0dfe5813..2f34bf8dd 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -729,7 +729,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp return } l := int(opts[i+1]) - if i < 2 || i+l > limit { + if l < 2 || i+l > limit { return } i += l diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 95ade0e5c..1f18213e5 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -49,9 +49,9 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 - // unspecifiedEthernetAddress is the unspecified ethernet address + // UnspecifiedEthernetAddress is the unspecified ethernet address // (all bits set to 0). - unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") // EthernetBroadcastAddress is an ethernet address that addresses every node // on a local link. @@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { return false } - if addr == unspecifiedEthernetAddress { + if addr == UnspecifiedEthernetAddress { return false } diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index bf9ccbf1a..adc04e855 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { }, { "Unspecified", - unspecifiedEthernetAddress, + UnspecifiedEthernetAddress, false, }, { @@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) { }, { "Unspecified", - unspecifiedEthernetAddress, + UnspecifiedEthernetAddress, false, }, { diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index b427c6170..b9db273d0 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -42,6 +42,14 @@ type Endpoint struct { nested.Endpoint } +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + if l := e.Endpoint.LinkAddress(); len(l) != 0 { + return l + } + return header.UnspecifiedEthernetAddress +} + // DeliverNetworkPacket implements stack.NetworkDispatcher. func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) @@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP // Capabilities implements stack.LinkEndpoint. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() + c := e.Endpoint.Capabilities() + if c&stack.CapabilityLoopback == 0 { + c |= stack.CapabilityResolutionRequired + } + return c } // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, proto, pkt) } // WritePackets implements stack.LinkEndpoint. func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - linkAddr := e.Endpoint.LinkAddress() + linkAddr := e.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) @@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 { } // ARPHardwareType implements stack.LinkEndpoint. -func (*Endpoint) ARPHardwareType() header.ARPHardwareType { +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone { + return a + } return header.ARPHardwareEther } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 4758a99ad..c3e4c3455 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/refs", "//pkg/refsvfs2", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index d23210503..fa2131c28 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -174,7 +173,7 @@ func (d *Device) Write(data []byte) (int64, error) { return 0, linuxerr.EBADFD } if !endpoint.IsAttached() { - return 0, syserror.EIO + return 0, linuxerr.EIO } dataLen := int64(len(data)) @@ -249,7 +248,7 @@ func (d *Device) Read() ([]byte, error) { for { info, ok := endpoint.Read() if !ok { - return nil, syserror.ErrWouldBlock + return nil, linuxerr.ErrWouldBlock } v, ok := d.encodePkt(&info) diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 6bce3af04..34ac62444 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -57,6 +57,11 @@ type SocketOptionsHandler interface { // OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE. OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) + + // WakeupWriters is invoked when the send buffer size for an endpoint is + // changed. The handler notifies the writers if the send buffer size is + // increased with setsockopt(2) for TCP endpoints. + WakeupWriters() } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -98,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { return v } +// WakeupWriters implements SocketOptionsHandler.WakeupWriters. +func (*DefaultSocketOptionsHandler) WakeupWriters() {} + // OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize. func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) { return v @@ -626,6 +634,9 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize) } so.sendBufferSize.Store(sendBufferSize) + if notify { + so.handler.WakeupWriters() + } } // GetReceiveBufferSize gets value for SO_RCVBUF option. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9192d8433..29c22bfd4 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -282,14 +282,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { return v } -// Clone makes a shallow copy of pk. -// -// Clone should be called in such cases so that no modifications is done to -// underlying packet payload. +// Clone makes a semi-deep copy of pk. The underlying packet payload is +// shared. Hence, no modifications is done to underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - buf: pk.buf, + buf: pk.buf.Clone(), reserved: pk.reserved, pushed: pk.pushed, consumed: pk.consumed, @@ -321,14 +319,14 @@ func (pk *PacketBuffer) Network() header.Network { } } -// CloneToInbound makes a shallow copy of the packet buffer to be used as an -// inbound packet. +// CloneToInbound makes a semi-deep copy of the packet buffer (similar to +// Clone) to be used as an inbound packet. // // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { newPk := &PacketBuffer{ - buf: pk.buf, + buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index a8da34992..87b023445 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,6 +123,32 @@ func TestPacketHeaderPush(t *testing.T) { } } +func TestPacketBufferClone(t *testing.T) { + data := concatViews(makeView(20), makeView(30), makeView(40)) + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + bytesToDelete := 30 + originalSize := data.Size() + + clonedPks := []*PacketBuffer{ + pk.Clone(), + pk.CloneToInbound(), + } + pk.Data().DeleteFront(bytesToDelete) + if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { + t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) + } + for _, clonedPk := range clonedPks { + if got := clonedPk.Data().Size(); got != originalSize { + t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) + } + } +} + func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c73890c4c..8e5c6edbf 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -119,8 +119,7 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // seed is a one-time random value initialized at stack startup - // and is used to seed the TCP port picking on active connections + // seed is a one-time random value initialized at stack startup. // // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 @@ -161,6 +160,10 @@ type Stack struct { // This is required to prevent potential ACK loops. // Setting this to 0 will disable all rate limiting. tcpInvalidRateLimit time.Duration + + // tsOffsetSecret is the secret key for generating timestamp offsets + // initialized at stack startup. + tsOffsetSecret uint32 } // UniqueID is an abstract generator of unique identifiers. @@ -384,6 +387,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, tcpInvalidRateLimit: defaultTCPInvalidRateLimit, + tsOffsetSecret: randomGenerator.Uint32(), } // Add specified network protocols. @@ -1819,14 +1823,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol return nic.setNUDConfigs(proto, c) } -// Seed returns a 32 bit value that can be used as a seed value for port -// picking, ISN generation etc. -// -// NOTE: The seed is generated once during stack initialization only. -func (s *Stack) Seed() uint32 { - return s.seed -} - // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. func (s *Stack) Rand() *rand.Rand { diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index 90a8ba6cf..93ea83cdc 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -386,6 +386,12 @@ type TCPSndBufState struct { // SndMTU is the smallest MTU seen in the control packets received. SndMTU int + + // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer + // is disabled. + // + // Must be accessed using atomic operations. + AutoTuneSndBufDisabled uint32 } // TCPEndpointStateInner contains the members of TCPEndpointState used directly diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index dda57e225..824cf6526 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -479,7 +479,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: d.stack.Seed(), + seed: d.stack.seed, } } if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b3d8951ff..55854ba59 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -321,28 +321,26 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } defer route.Release() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = owner + if e.ops.GetHeaderIncluded() { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }) if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, err } - } else { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()), - Data: buffer.View(payloadBytes).ToVectorisedView(), - }) - pkt.Owner = owner - if err := route.WritePacket(stack.NetworkHeaderParams{ - Protocol: e.TransProto, - TTL: route.DefaultTTL(), - TOS: stack.DefaultTOS, - }, pkt); err != nil { - return 0, err - } + return int64(len(payloadBytes)), nil } + if err := route.WritePacket(stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { + return 0, err + } return int64(len(payloadBytes)), nil } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 8436d2cf0..c3922bbe5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -96,6 +96,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index aa413ad05..9560ed43c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 { // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack + stack *stack.Stack + protocol *protocol // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. @@ -119,9 +120,10 @@ func timeStamp(clock tcpip.Clock) uint32 { } // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stk, + protocol: protocol, rcvWnd: rcvWnd, hasher: sha1.New(), v6Only: v6Only, @@ -213,7 +215,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header return nil, err } - n := newEndpoint(l.stack, netProto, queue) + n := newEndpoint(l.stack, l.protocol, netProto, queue) n.ops.SetV6Only(l.v6Only) n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID @@ -247,7 +249,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -600,7 +602,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr)), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } @@ -726,24 +728,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.isRegistered = true - // clear the tsOffset for the newly created - // endpoint as the Timestamp was already - // randomly offset when the original SYN-ACK was - // sent above. - n.TSOffset = 0 + // Reset the tsOffset for the newly created endpoint to the one + // that we have used in SYN-ACK in order to calculate RTT. + n.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, s.dstAddr, s.srcAddr) // Switch state to connected. n.isConnectNotified = true - n.transitionToStateEstablishedLocked(&handshake{ - ep: n, - iss: iss, - ackNum: irs + 1, - rcvWnd: seqnum.Size(n.initialReceiveWindow()), - sndWnd: s.window, - rcvWndScale: e.rcvWndScaleForHandshake(), - sndWndScale: rcvdSynOptions.WS, - mss: rcvdSynOptions.MSS, - }) + h := &handshake{ + ep: n, + iss: iss, + ackNum: irs + 1, + rcvWnd: seqnum.Size(n.initialReceiveWindow()), + sndWnd: s.window, + rcvWndScale: e.rcvWndScaleForHandshake(), + sndWndScale: rcvdSynOptions.WS, + mss: rcvdSynOptions.MSS, + sampleRTTWithTSOnly: true, + } + h.transitionToStateEstablishedLocked(s) // Requeue the segment if the ACK completing the handshake has more info // to be procesed by the newly established endpoint. @@ -779,7 +781,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.ops.GetV6Only() - ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) + ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 93ed161f9..f85775a48 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -105,6 +105,11 @@ type handshake struct { // sendSYNOpts is the cached values for the SYN options to be sent. sendSYNOpts header.TCPSynOptions + + // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't + // tell; then RTT can only be sampled when the incoming segment has timestamp + // options enabled. + sampleRTTWithTSOnly bool } func (e *endpoint) newHandshake() *handshake { @@ -117,6 +122,8 @@ func (e *endpoint) newHandshake() *handshake { h.resetState() // Store reference to handshake state in endpoint. e.h = h + // By the time handshake is created, e.ID is already initialized. + e.TSOffset = timeStampOffset(e.protocol.tsOffsetSecret, e.ID.LocalAddress, e.ID.RemoteAddress) return h } @@ -150,7 +157,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret) } // generateSecureISN generates a secure Initial Sequence number based on the @@ -266,8 +273,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // and the handshake is completed. if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted - - h.ep.transitionToStateEstablishedLocked(h) + h.transitionToStateEstablishedLocked(s) h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil @@ -402,9 +408,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } + h.state = handshakeCompleted - h.ep.transitionToStateEstablishedLocked(h) + h.transitionToStateEstablishedLocked(s) // Requeue the segment if the ACK completing the handshake has more info // to be procesed by the newly established endpoint. @@ -557,6 +564,10 @@ func (h *handshake) complete() tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, h.sendSYNOpts) + // If we have ever retransmitted the SYN-ACK or + // SYN segment, we should only measure RTT if + // TS option is present. + h.sampleRTTWithTSOnly = true } case wakerForNotification: @@ -600,6 +611,38 @@ func (h *handshake) complete() tcpip.Error { return nil } +// transitionToStateEstablisedLocked transitions the endpoint of the handshake +// to an established state given the last segment received from peer. It also +// initializes sender/receiver. +func (h *handshake) transitionToStateEstablishedLocked(s *segment) { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + var rtt time.Duration + if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 { + rtt = time.Duration(h.ep.timestamp()-s.parsedOptions.TSEcr) * time.Millisecond + } + if !h.sampleRTTWithTSOnly && rtt == 0 { + rtt = h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) + } + + if rtt > 0 { + h.ep.snd.updateRTO(rtt) + } + + h.ep.rcvQueueInfo.rcvQueueMu.Lock() + h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + h.ep.rcvQueueInfo.rcvQueueMu.Unlock() + + h.ep.setEndpointState(StateEstablished) +} + type backoffTimer struct { timeout time.Duration maxTimeout time.Duration @@ -965,26 +1008,6 @@ func (e *endpoint) completeWorkerLocked() { } } -// transitionToStateEstablisedLocked transitions a given endpoint -// to an established state using the handshake parameters provided. -// It also initializes sender/receiver. -func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) - // Bootstrap the auto tuning algorithm. Starting at zero will - // result in a really large receive window after the first auto - // tuning adjustment. - e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) - e.rcvQueueInfo.rcvQueueMu.Unlock() - - e.setEndpointState(StateEstablished) -} - // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 044123185..4937d126f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,7 +20,6 @@ import ( "fmt" "io" "math" - "math/rand" "runtime" "strings" "sync/atomic" @@ -378,6 +377,7 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` + protocol *protocol `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 @@ -803,9 +803,10 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, + stack: s, + protocol: protocol, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: header.TCPProtocolNumber, @@ -874,7 +875,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset(e.stack.Rand()) + e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) @@ -1717,6 +1718,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { return rcvBufSz } +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *endpoint) OnSetSendBufferSize(sz int64) int64 { + atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1) + return sz +} + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *endpoint) WakeupWriters() { + e.LockUser() + defer e.UnlockUser() + + sendBufferSize := e.getSendBufferSize() + e.sndQueueInfo.sndQueueMu.Lock() + notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1 + e.sndQueueInfo.sndQueueMu.Unlock() + + if notify { + e.waiterQueue.Notify(waiter.WritableEvents) + } +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -2177,7 +2199,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h := jenkins.Sum32(e.stack.Seed()) + h := jenkins.Sum32(e.protocol.portOffsetSecret) for _, s := range [][]byte{ []byte(e.ID.LocalAddress), []byte(e.ID.RemoteAddress), @@ -2329,6 +2351,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.setEndpointState(StateEstablished) + // Set the new auto tuned send buffer size after entering + // established state. + e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */) } if run { @@ -2763,13 +2788,20 @@ func (e *endpoint) updateSndBufferUsage(v int) { e.sndQueueInfo.sndQueueMu.Lock() notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 e.sndQueueInfo.SndBufUsed -= v + + // Get the new send buffer size with auto tuning, but do not set it + // unless we decide to notify the writers. + newSndBufSz := e.computeTCPSendBufferSize() + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { + // Set the new send buffer size calculated from auto tuning. + e.ops.SetSendBufferSize(newSndBufSz, false /* notify */) e.waiterQueue.Notify(waiter.WritableEvents) } } @@ -2896,17 +2928,22 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset(rng *rand.Rand) uint32 { +func timeStampOffset(secret uint32, src, dst tcpip.Address) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on // why this is required. // - // NOTE: This is not completely to spec as normally this should be - // initialized in a manner analogous to how sequence numbers are - // randomized per connection basis. But for now this is sufficient. - return rng.Uint32() + // TODO(https://gvisor.dev/issues/6473): This is not really secure as + // it does not use the recommended algorithm linked above. + h := jenkins.Sum32(secret) + // Per hash.Hash.Writer: + // + // It never returns an error. + _, _ = h.Write([]byte(src)) + _, _ = h.Write([]byte(dst)) + return h.Sum32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint @@ -3091,3 +3128,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti Max: ss.Max, } } + +// computeTCPSendBufferSize implements auto tuning of send buffer size and +// returns the new send buffer size. +func (e *endpoint) computeTCPSendBufferSize() int64 { + curSndBufSz := int64(e.getSendBufferSize()) + + // Auto tuning is disabled when the user explicitly sets the send + // buffer size with SO_SNDBUF option. + if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 { + return curSndBufSz + } + + const packetOverheadFactor = 2 + curMSS := e.snd.MaxPayloadSize + numSeg := InitialCwnd + if numSeg < e.snd.SndCwnd { + numSeg = e.snd.SndCwnd + } + + // SndCwnd indicates the number of segments that can be sent. This means + // that the sender can send upto #SndCwnd segments and the send buffer + // size should be set to SndCwnd*MSS to accommodate sending of all the + // segments. + newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor) + if newSndBufSz < curSndBufSz { + return curSndBufSz + } + if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz { + newSndBufSz = int64(ss.Max) + } + + return newSndBufSz +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 952ccacdd..f2e8b3840 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) { snd.probeTimer.init(s.Clock(), &snd.probeWaker) } e.stack = s + e.protocol = protocolFromStack(s) e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := EndpointState(e.origEndpointState) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2e709ed78..78745ea86 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), - listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), + listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), } } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 18b834243..00a083dbe 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -96,6 +96,11 @@ type protocol struct { maxRetries uint32 synRetries uint8 dispatcher dispatcher + + // The following secrets are initialized once and stay unchanged after. + seqnumSecret uint32 + portOffsetSecret uint32 + tsOffsetSecret uint32 } // Number returns the tcp protocol number. @@ -105,7 +110,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { // NewEndpoint creates a new tcp endpoint. func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newEndpoint(p.stack, netProto, waiterQueue), nil + return newEndpoint(p.stack, p, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently @@ -292,22 +297,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip case *tcpip.TCPMinRTOOption: p.mu.Lock() + defer p.mu.Unlock() if *v < 0 { p.minRTO = MinRTO + } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO { + p.minRTO = minRTO } else { - p.minRTO = time.Duration(*v) + return &tcpip.ErrInvalidOptionValue{} } - p.mu.Unlock() return nil case *tcpip.TCPMaxRTOOption: p.mu.Lock() + defer p.mu.Unlock() if *v < 0 { p.maxRTO = MaxRTO + } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO { + p.maxRTO = maxRTO } else { - p.maxRTO = time.Duration(*v) + return &tcpip.ErrInvalidOptionValue{} } - p.mu.Unlock() return nil case *tcpip.TCPMaxRetriesOption: @@ -479,7 +488,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { maxRTO: MaxRTO, maxRetries: MaxRetries, recovery: tcpip.TCPRACKLossDetection, + seqnumSecret: s.Rand().Uint32(), + portOffsetSecret: s.Rand().Uint32(), + tsOffsetSecret: s.Rand().Uint32(), } p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } + +// protocolFromStack retrieves the tcp.protocol instance from stack s. +func protocolFromStack(s *stack.Stack) *protocol { + return s.TransportProtocolInstance(ProtocolNumber).(*protocol) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 92a66f17e..a1f1c4e59 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -382,6 +382,9 @@ func (s *sender) updateRTO(rtt time.Duration) { if s.RTO < s.minRTO { s.RTO = s.minRTO } + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO + } } // resendSegment resends the first unacknowledged segment. @@ -1415,9 +1418,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ackLeft -= datalen } - // Update the send buffer usage and notify potential waiters. - s.ep.updateSndBufferUsage(int(acked)) - // Clear SACK information for all acked data. s.ep.scoreboard.Delete(s.SndUna) @@ -1437,6 +1437,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update the send buffer usage and notify potential waiters. + s.ep.updateSndBufferUsage(int(acked)) + // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 89e9fb886..c35db7c95 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -33,7 +33,6 @@ const ( tsOptionSize = 12 maxTCPOptionSize = 40 mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload - latency = 5 * time.Millisecond ) func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) { @@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en if !enableRACK { setStackTCPRecovery(t, c, 0) } - createConnectedWithSACKAndTS(c) + // The delay should be below initial RTO (1s) otherwise retransimission + // will start. Choose a relatively large value so that estimated RTT + // keeps high even after a few rounds of undelayed RTT samples. + c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */) data := make([]byte, numPackets*maxPayload) for i := range data { @@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en for i := 0; i < numPackets; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - // This delay is added to increase RTT as low RTT can cause TLP - // before sending ACK. - time.Sleep(latency) } return data diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 83e0653b9..6255355bb 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -35,13 +35,13 @@ import ( // SACKPermitted option enabled if the stack in the context has the SACK support // enabled. func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) } // createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS // option enabled if the stack in the context has SACK and TS enabled. func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) } func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { @@ -108,7 +108,7 @@ func TestSackDisabledConnect(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) + rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) data := []byte{1, 2, 3} @@ -170,7 +170,7 @@ func TestSackPermittedAccept(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) // Now verify no SACK blocks are // received when sack is disabled. data := []byte{1, 2, 3} @@ -244,7 +244,7 @@ func TestSackDisabledAccept(t *testing.T) { setStackSACKPermitted(t, c, sackEnabled) setStackTCPRecovery(t, c, 0) - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Now verify no SACK blocks are // received when sack is disabled. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 031f01357..bf726e86a 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -2143,7 +2144,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) } - c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. @@ -2535,7 +2536,7 @@ func TestScaledWindowAccept(t *testing.T) { // Do 3-way handshake. // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -3532,6 +3533,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) { t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -3554,8 +3561,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) { ), ) } - // Wait for the connection to timeout after MaxRetries retransmits. - initRTO := 1 * time.Second select { case <-notifyCh: case <-time.After((2 << numRetries) * initRTO): @@ -3590,9 +3595,13 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - opt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + minRTOOpt := tcpip.TCPMinRTOOption(rto / 2) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + maxRTOOpt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err) } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3618,8 +3627,8 @@ func TestMaxRTO(t *testing.T) { checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { - t.Errorf("Retransmit interval not capped to MaxRTO.\n") + if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() { + t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed) } } } @@ -3670,6 +3679,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + minRTOOpt := tcpip.TCPMinRTOOption(time.Second) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) // Disabling PMTU discovery causes all packets sent from this socket to @@ -6304,7 +6317,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -6385,7 +6398,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // maximum buffer size defined above. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) // NOTE: The timestamp values in the sent packets are meaningless to the // peer so we just increment the timestamp value by 1 every batch as we @@ -6515,7 +6528,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // maximum buffer size used by stack. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) @@ -7430,6 +7443,11 @@ func TestTCPUserTimeout(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + initRTO := 1 * time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -7440,7 +7458,6 @@ func TestTCPUserTimeout(t *testing.T) { // Ensure that on the next retransmit timer fire, the user timeout has // expired. - initRTO := 1 * time.Second userTimeout := initRTO / 2 v := tcpip.TCPUserTimeoutOption(userTimeout) if err := c.EP.SetSockOpt(&v); err != nil { @@ -7954,6 +7971,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } +func TestHandshakeRTT(t *testing.T) { + type testCase struct { + connect bool + tsEnabled bool + useCookie bool + retrans bool + delay time.Duration + wantRTT time.Duration + } + var testCases []testCase + for _, connect := range []bool{false, true} { + for _, tsEnabled := range []bool{false, true} { + for _, useCookie := range []bool{false, true} { + for _, retrans := range []bool{false, true} { + if connect && useCookie { + continue + } + delay := 800 * time.Millisecond + if retrans { + delay = 1200 * time.Millisecond + } + wantRTT := delay + // If syncookie is enabled, sample RTT only when TS option is enabled. + if !retrans && useCookie && !tsEnabled { + wantRTT = 0 + } + // If retransmitted, sample RTT only when TS option is enabled. + if retrans && !tsEnabled { + wantRTT = 0 + } + testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT}) + } + } + } + } + for _, tt := range testCases { + tt := tt + t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) { + t.Parallel() + c := context.New(t, defaultMTU) + if tt.useCookie { + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + synOpts := header.TCPSynOptions{} + if tt.tsEnabled { + synOpts.TS = true + synOpts.TSVal = 42 + } + if tt.connect { + c.CreateConnectedWithOptions(synOpts, tt.delay) + } else { + synOpts.MSS = defaultIPv4MSS + synOpts.WS = -1 + c.AcceptWithOptions(-1, synOpts, tt.delay) + } + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + if got := time.Duration(info.RTT).Round(tt.wantRTT); got != tt.wantRTT { + t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT) + } + if info.RTTVar != 0 && tt.wantRTT == 0 { + t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar) + } + if info.RTTVar == 0 && tt.wantRTT != 0 { + t.Fatalf("got info.RTTVar=0, expect non zero") + } + }) + } +} + +func TestSetRTO(t *testing.T) { + c := context.New(t, defaultMTU) + minRTO, maxRTO := tcpRTOMinMax(t, c) + for _, tt := range []struct { + name string + RTO time.Duration + minRTO time.Duration + maxRTO time.Duration + err tcpip.Error + }{ + { + name: "invalid minRTO", + minRTO: maxRTO + time.Second, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "invalid maxRTO", + maxRTO: minRTO - time.Millisecond, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "valid minRTO", + minRTO: maxRTO - time.Second, + }, + { + name: "valid maxRTO", + maxRTO: minRTO + time.Millisecond, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + var opt tcpip.SettableTransportProtocolOption + if tt.minRTO > 0 { + min := tcpip.TCPMinRTOOption(tt.minRTO) + opt = &min + } + if tt.maxRTO > 0 { + max := tcpip.TCPMaxRTOOption(tt.maxRTO) + opt = &max + } + err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt) + if got, want := err, tt.err; got != want { + t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want) + } + if tt.err == nil { + minRTO, maxRTO := tcpRTOMinMax(t, c) + if tt.minRTO > 0 && tt.minRTO != minRTO { + t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO) + } + if tt.maxRTO > 0 && tt.maxRTO != maxRTO { + t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO) + } + } + }) + } +} + +func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) { + t.Helper() + var minOpt tcpip.TCPMinRTOOption + var maxOpt tcpip.TCPMaxRTOOption + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil { + t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err) + } + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil { + t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err) + } + return time.Duration(minOpt), time.Duration(maxOpt) +} + // generateRandomPayload generates a random byte slice of the specified length // causing a fatal test failure if it is unable to do so. func generateRandomPayload(t *testing.T, n int) []byte { @@ -7964,3 +8126,185 @@ func generateRandomPayload(t *testing.T, n int) []byte { } return buf } + +func TestSendBufferTuning(t *testing.T) { + const maxPayload = 536 + const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + const packetOverheadFactor = 2 + + testCases := []struct { + name string + autoTuningDisabled bool + }{ + {"autoTuningDisabled", true}, + {"autoTuningEnabled", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + // Set the stack option for send buffer size. + const defaultSndBufSz = maxPayload * tcp.InitialCwnd + const maxSndBufSz = defaultSndBufSz * 10 + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + } + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + oldSz := c.EP.SocketOptions().GetSendBufferSize() + if oldSz != defaultSndBufSz { + t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) + } + + if tc.autoTuningDisabled { + c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) + } + + data := make([]byte, maxPayload) + for i := range data { + data[i] = byte(i) + } + + w, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&w, waiter.WritableEvents) + defer c.WQ.EventUnregister(&w) + + bytesRead := 0 + for { + // Packets will be sent till the send buffer + // size is reached. + var r bytes.Reader + r.Reset(data[bytesRead : bytesRead+maxPayload]) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } + + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) + bytesRead += maxPayload + data = append(data, data...) + } + + // Send an ACK and wait for connection to become writable again. + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Write failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + + outSz := int64(defaultSndBufSz) + if !tc.autoTuningDisabled { + // Calculate the new auto tuned send buffer. + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + outSz = (int64(info.SndCwnd) * packetOverheadFactor * (maxPayload)) + } + + if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz { + t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz) + } + }) + } +} + +func TestTimestampSynCookies(t *testing.T) { + clock := faketime.NewManualClock() + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: defaultMTU, + Clock: clock, + }) + defer c.Cleanup() + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() + + tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(42, 0, tcpOpts[2:]) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + RcvWnd: seqnum.Size(512), + SeqNum: iss, + TCPOpts: tcpOpts[:], + }) + // Get the TSVal of SYN-ACK. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + initialTSVal := tcpHdr.ParsedOptions().TSVal + + header.EncodeTSOption(420, initialTSVal, tcpOpts[2:]) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + RcvWnd: seqnum.Size(512), + SeqNum: iss + 1, + AckNum: c.IRS + 1, + TCPOpts: tcpOpts[:], + }) + c.EP, _, err = ep.Accept(nil) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + defer wq.EventUnregister(&we) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } else if err != nil { + t.Fatalf("failed to accept: %s", err) + } + + const elapsed = 200 * time.Millisecond + clock.Advance(elapsed) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // The endpoint should have a correct TSOffset so that the received TSVal + // should match our expectation. + if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, initialTSVal+uint32(elapsed.Milliseconds()); got != want { + t.Fatalf("got TSVal = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 1deb1fe4d..65925daa5 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -32,7 +32,7 @@ import ( // createConnectedWithTimestampOption creates and connects c.ep with the // timestamp option enabled. func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) + return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1}) } // TestTimeStampEnabledConnect tests that netstack sends the timestamp option on @@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithOptions(header.TCPSynOptions{}) + c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { @@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) + c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} @@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd } t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 96e4849d2..6e55a7a32 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -122,6 +122,9 @@ type Options struct { // MTU indicates the maximum transmission unit on the link layer. MTU uint32 + + // Clock that is used by Stack. + Clock tcpip.Clock } // Context provides an initialized Network stack and a link layer endpoint @@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context { stackOpts := stack.Options{ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: opts.Clock, } if opts.EnableV4 { stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) @@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { ) } +// CreateConnectedWithOptionsNoDelay just calls CreateConnectedWithOptions +// without delay. +func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint { + return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */) +} + // CreateConnectedWithOptions creates and connects c.ep with the specified TCP // options enabled and returns a RawEndpoint which represents the other end of -// the connection. +// the connection. It delays before a SYNACK is sent. This makes c.EP have a +// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps +// reduce flakiness. // // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { @@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // TS value. mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), + synChecker := checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{ + MSS: mss, + TS: true, + WS: int(c.WindowScale), + SACKPermitted: c.SACKEnabled(), + }), ) + checker.IPv4(c.t, b, synChecker) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } @@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Build SYN-ACK. c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) iss := seqnum.Value(TestInitialSequenceNumber) + if delay > 0 { + // Sleep so that RTT is increased. + time.Sleep(delay) + } c.SendPacket(nil, &Headers{ SrcPort: tcpSeg.DestinationPort(), DstPort: tcpSeg.SourcePort(), @@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }) // Read ACK. - ackPacket := c.GetPacket() + var ackPacket []byte + // Ignore retransimitted SYN packets. + for { + packet := c.GetPacket() + if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 { + checker.IPv4(c.t, packet, synChecker) + } else { + ackPacket = packet + break + } + } // Verify TCP header fields. tcpCheckers := []checker.TransportChecker{ @@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * } } -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. +// AcceptWithOptionsNoDelay delegates call to AcceptWithOptions without delay. +func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */) +} + +// AcceptWithOptions initializes a listening endpoint and connects to it with +// the provided options enabled. It delays before the final ACK of the 3WHS is +// sent. It also verifies that the SYN-ACK has the expected values for the +// provided options. // // The function returns a RawEndpoint representing the other end of the accepted // endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { // Create EP and start listening. wq := &waiter.Queue{} ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) @@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) + rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption // PassiveConnectWithOptions. func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) + c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */) } // PassiveConnectWithOptions initiates a new connection (with the specified TCP // options enabled) to the port on which the Context.ep is listening for new // connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. +// the enabled options. The final ACK of the handshake is delayed by specified +// duration. // // NOTE: MSS is not a negotiated option and it can be asymmetric // in each direction. This function uses the maxPayload to set the MSS to be @@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { +func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint { c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 @@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions ackHeaders.TCPOpts = opts[:] } - // Send ACK. + // Send ACK, delay if needed. + if delay > 0 { + time.Sleep(delay) + } c.SendPacket(nil, ackHeaders) c.RcvdWindowScale = uint8(rcvdSynOptions.WS) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 82a3f2287..108580508 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -266,7 +266,7 @@ func (e *endpoint) Close() { for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = make(map[multicastMembership]struct{}) + e.multicastMemberships = nil // Close the receive list and drain it. e.rcvMu.Lock() |