summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/BUILD4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go12
-rw-r--r--pkg/tcpip/transport/raw/BUILD8
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go20
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/tcp/BUILD5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go28
-rw-r--r--pkg/tcpip/transport/tcp/connect.go64
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go3
-rw-r--r--pkg/tcpip/transport/tcp/cubic_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go322
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go27
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go14
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go72
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go82
-rw-r--r--pkg/tcpip/transport/tcp/sack.go16
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go4
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/segment.go8
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go34
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go448
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go117
-rw-r--r--pkg/tcpip/transport/tcp/timer.go2
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD2
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go4
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go4
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go55
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go8
-rw-r--r--pkg/tcpip/transport/udp/protocol.go12
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go22
40 files changed, 1192 insertions, 348 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 84a2b53b7..62182a3e6 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -23,8 +23,8 @@ go_library(
"icmp_packet_list.go",
"protocol.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/icmp",
- imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b8005093a..ab9e80747 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -18,11 +18,11 @@ import (
"encoding/binary"
"sync"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
@@ -127,6 +127,9 @@ func (e *endpoint) Close() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -419,6 +422,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ if addr.Addr == "" {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 332b3cd33..99b8c4093 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,9 +15,9 @@
package icmp
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// saveData saves icmpPacket.data field.
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 954fde9d8..c89538131 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -26,12 +26,12 @@ import (
"encoding/binary"
"fmt"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 6d3f0130e..34a14bf7f 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,6 +1,4 @@
-package(
- licenses = ["notice"], # Apache 2.0
-)
+package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
@@ -24,8 +22,8 @@ go_library(
"endpoint_state.go",
"packet_list.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw",
- imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e4ff50c91..42aded77f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -16,7 +16,7 @@
// sockets allow applications to:
//
// * manually write and inspect transport layer headers and payloads
-// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
// * optionally write and inspect network layer and link layer headers for
// packets
//
@@ -29,11 +29,11 @@ package raw
import (
"sync"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
@@ -147,6 +147,9 @@ func (ep *endpoint) Close() {
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
@@ -295,6 +298,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
+ if addr.Addr == "" {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index e8907ebb1..cb5534d90 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,9 +15,9 @@
package raw
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// saveData saves packet.data field.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 9db38196b..4cd25e8e2 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -21,6 +21,7 @@ go_library(
"accept.go",
"connect.go",
"cubic.go",
+ "cubic_state.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
@@ -38,8 +39,8 @@ go_library(
"tcp_segment_list.go",
"timer.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp",
- imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/rand",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d05259c0a..52fd1bfa3 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -22,13 +22,13 @@ import (
"sync"
"time"
- "gvisor.googlesource.com/gvisor/pkg/rand"
- "gvisor.googlesource.com/gvisor/pkg/sleep"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -213,6 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
+ n.amss = mssForRoute(&n.route)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
@@ -232,7 +233,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+ n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
+ // Bootstrap the auto tuning algorithm. Starting at zero will result in
+ // a large step function on the first window adjustment causing the
+ // window to grow to a really large value.
+ n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
return n, nil
}
@@ -249,7 +254,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
// Perform the 3-way handshake.
- h := newHandshake(ep, l.rcvWnd)
+ h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
@@ -359,16 +364,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
- // Send SYN with window scaling because we currently
+
+ // Send SYN without window scaling because we currently
// dont't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
+ mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
+ MSS: uint16(mss),
}
sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index dd671f7ce..00d2ae524 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -18,14 +18,14 @@ import (
"sync"
"time"
- "gvisor.googlesource.com/gvisor/pkg/rand"
- "gvisor.googlesource.com/gvisor/pkg/sleep"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// maxSegmentsPerWake is the maximum number of segments to process in the main
@@ -78,6 +78,9 @@ type handshake struct {
// mss is the maximum segment size received from the peer.
mss uint16
+ // amss is the maximum segment size advertised by us to the peer.
+ amss uint16
+
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
@@ -87,11 +90,24 @@ type handshake struct {
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ rcvWndScale := ep.rcvWndScaleForHandshake()
+
+ // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: FindWndScale(rcvWnd),
+ rcvWndScale: int(rcvWndScale),
}
h.resetState()
return h
@@ -224,7 +240,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.ep.state = StateSynRecv
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
- WS: h.rcvWndScale,
+ WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
@@ -233,6 +249,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
+ MSS: h.ep.amss,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
@@ -277,6 +294,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
SACKPermitted: h.ep.sackPermitted,
+ MSS: h.ep.amss,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -419,12 +437,15 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
+ h.ep.amss = mssForRoute(&h.ep.route)
+
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
SACKPermitted: bool(sackEnabled),
+ MSS: h.ep.amss,
}
// Execute is also called in a listen context so we want to make sure we
@@ -433,6 +454,11 @@ func (h *handshake) execute() *tcpip.Error {
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ if h.sndWndScale < 0 {
+ // Disable window scaling if the peer did not send us
+ // the window scaling option.
+ synOpts.WS = -1
+ }
}
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
@@ -554,13 +580,6 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
}
func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- // The MSS in opts is automatically calculated as this function is
- // called from many places and we don't want every call point being
- // embedded with the MSS calculation.
- if opts.MSS == 0 {
- opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
- }
-
options := makeSynOptions(opts)
err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
putOptions(options)
@@ -861,7 +880,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// This is an active connection, so we must initiate the 3-way
// handshake, and then inform potential waiters about its
// completion.
- h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ initialRcvWnd := e.initialReceiveWindow()
+ h := newHandshake(e, seqnum.Size(initialRcvWnd))
e.mu.Lock()
h.ep.state = StateSynSent
e.mu.Unlock()
@@ -886,8 +906,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // boot strap the auto tuning algorithm. Starting at zero will
+ // result in a large step function on the first proper causing
+ // the window to just go to a really large value after the first
+ // RTT itself.
+ e.rcvAutoParams.prevCopied = initialRcvWnd
e.rcvListMu.Unlock()
}
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index e618cd2b9..7b1f5e763 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@ import (
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
type cubicState struct {
// wLastMax is the previous wMax value.
wLastMax float64
@@ -33,7 +34,7 @@ type cubicState struct {
// t denotes the time when the current congestion avoidance
// was entered.
- t time.Time
+ t time.Time `state:".(unixTime)"`
// numCongestionEvents tracks the number of congestion events since last
// RTO.
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 43bcfa070..d9f79e8c5 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -18,15 +18,15 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/waiter"
)
func TestV4MappedConnectOnV6Only(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23422ca5e..beb90afb5 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,19 +17,20 @@ package tcp
import (
"fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
"time"
- "gvisor.googlesource.com/gvisor/pkg/rand"
- "gvisor.googlesource.com/gvisor/pkg/sleep"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tmutex"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tmutex"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// EndpointState represents the state of a TCP endpoint.
@@ -116,6 +117,7 @@ const (
notifyDrain
notifyReset
notifyKeepaliveChanged
+ notifyMSSChanged
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -131,6 +133,42 @@ type SACKInfo struct {
NumBlocks int
}
+// rcvBufAutoTuneParams are used to hold state variables to compute
+// the auto tuned recv buffer size.
+//
+// +stateify savable
+type rcvBufAutoTuneParams struct {
+ // measureTime is the time at which the current measurement
+ // was started.
+ measureTime time.Time `state:".(unixTime)"`
+
+ // copied is the number of bytes copied out of the receive
+ // buffers since this measure began.
+ copied int
+
+ // prevCopied is the number of bytes copied out of the receive
+ // buffers in the previous RTT period.
+ prevCopied int
+
+ // rtt is the non-smoothed minimum RTT as measured by observing the time
+ // between when a byte is first acknowledged and the receipt of data
+ // that is at least one window beyond the sequence number that was
+ // acknowledged.
+ rtt time.Duration
+
+ // rttMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ rttMeasureSeqNumber seqnum.Value
+
+ // rttMeasureTime is the absolute time at which the current rtt
+ // measurement period began.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ // disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ disabled bool
+}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -164,18 +202,23 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
- rcvBufUsed int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+ rcvAutoParams rcvBufAutoTuneParams
+ // zeroWindow indicates that the window was closed due to receive buffer
+ // space being filled up. This is set by the worker goroutine before
+ // moving a segment to the rcvList. This setting is cleared by the
+ // endpoint when a Read() call reads enough data for the new window to
+ // be non-zero.
+ zeroWindow bool
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
id stack.TransportEndpointID
- // state endpointState `state:".(endpointState)"`
- // pState ProtocolState
state EndpointState `state:".(EndpointState)"`
isPortReserved bool `state:"manual"`
@@ -269,6 +312,10 @@ type endpoint struct {
// in SYN-RCVD state.
synRcvdCount int
+ // userMSS if non-zero is the MSS value explicitly set by the user
+ // for this endpoint using the TCP_MAXSEG setsockopt.
+ userMSS int
+
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
// protocol goroutine is signaled via sndWaker.
@@ -286,7 +333,7 @@ type endpoint struct {
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
- cc CongestionControlOption
+ cc tcpip.CongestionControlOption
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
@@ -338,6 +385,9 @@ type endpoint struct {
bindAddress tcpip.Address
connectingAddress tcpip.Address
+ // amss is the advertised MSS to the peer by this endpoint.
+ amss uint16
+
gso *stack.GSO
}
@@ -372,8 +422,8 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
netProto: netProto,
waiterQueue: waiterQueue,
state: StateInitial,
- rcvBufSize: DefaultBufferSize,
- sndBufSize: DefaultBufferSize,
+ rcvBufSize: DefaultReceiveBufferSize,
+ sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
reuseAddr: true,
keepalive: keepalive{
@@ -394,11 +444,16 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.rcvBufSize = rs.Default
}
- var cs CongestionControlOption
+ var cs tcpip.CongestionControlOption
if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
+ var mrb tcpip.ModerateReceiveBufferOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ e.rcvAutoParams.disabled = !bool(mrb)
+ }
+
if p := stack.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -407,6 +462,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.workMu.Init()
e.workMu.Lock()
e.tsOffset = timeStampOffset()
+
return e
}
@@ -550,6 +606,83 @@ func (e *endpoint) cleanupLocked() {
tcpip.DeleteDanglingEndpoint(e)
}
+// initialReceiveWindow returns the initial receive window to advertise in the
+// SYN/SYN-ACK.
+func (e *endpoint) initialReceiveWindow() int {
+ rcvWnd := e.receiveBufferAvailable()
+ if rcvWnd > math.MaxUint16 {
+ rcvWnd = math.MaxUint16
+ }
+ routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+ if rcvWnd > routeWnd {
+ rcvWnd = routeWnd
+ }
+ return rcvWnd
+}
+
+// ModerateRecvBuf adjusts the receive buffer and the advertised window
+// based on the number of bytes copied to user space.
+func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.rcvListMu.Lock()
+ if e.rcvAutoParams.disabled {
+ e.rcvListMu.Unlock()
+ return
+ }
+ now := time.Now()
+ if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
+ e.rcvAutoParams.copied += copied
+ e.rcvListMu.Unlock()
+ return
+ }
+ prevRTTCopied := e.rcvAutoParams.copied + copied
+ prevCopied := e.rcvAutoParams.prevCopied
+ rcvWnd := 0
+ if prevRTTCopied > prevCopied {
+ // The minimal receive window based on what was copied by the app
+ // in the immediate preceding RTT and some extra buffer for 16
+ // segments to account for variations.
+ // We multiply by 2 to account for packet losses.
+ rcvWnd = prevRTTCopied*2 + 16*int(e.amss)
+
+ // Scale for slow start based on bytes copied in this RTT vs previous.
+ grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied
+
+ // Multiply growth factor by 2 again to account for sender being
+ // in slow-start where the sender grows it's congestion window
+ // by 100% per RTT.
+ rcvWnd += grow * 2
+
+ // Make sure auto tuned buffer size can always receive upto 2x
+ // the initial window of 10 segments.
+ if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd {
+ rcvWnd = minRcvWnd
+ }
+
+ // Cap the auto tuned buffer size by the maximum permissible
+ // receive buffer size.
+ if max := e.maxReceiveBufferSize(); rcvWnd > max {
+ rcvWnd = max
+ }
+
+ // We do not adjust downwards as that can cause the receiver to
+ // reject valid data that might already be in flight as the
+ // acceptable window will shrink.
+ if rcvWnd > e.rcvBufSize {
+ e.rcvBufSize = rcvWnd
+ e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ }
+
+ // We only update prevCopied when we grow the buffer because in cases
+ // where prevCopied > prevRTTCopied the existing buffer is already big
+ // enough to handle the current rate and we don't need to do any
+ // adjustments.
+ e.rcvAutoParams.prevCopied = prevRTTCopied
+ }
+ e.rcvAutoParams.measureTime = now
+ e.rcvAutoParams.copied = 0
+ e.rcvListMu.Unlock()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
@@ -595,10 +728,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
s.decRef()
}
- scale := e.rcv.rcvWndScale
- wasZero := e.zeroReceiveWindow(scale)
e.rcvBufUsed -= len(v)
- if wasZero && !e.zeroReceiveWindow(scale) {
+ // If the window was zero before this read and if the read freed up
+ // enough buffer space for the scaled window to be non-zero then notify
+ // the protocol goroutine to send a window update.
+ if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
+ e.zeroWindow = false
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -785,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -818,9 +964,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
size = math.MaxInt32 / 2
}
- wasZero := e.zeroReceiveWindow(scale)
e.rcvBufSize = size
- if wasZero && !e.zeroReceiveWindow(scale) {
+ e.rcvAutoParams.disabled = true
+ if e.zeroWindow && !e.zeroReceiveWindow(scale) {
+ e.zeroWindow = false
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
@@ -898,6 +1045,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ // Acquire the work mutex as we may need to
+ // reinitialize the congestion control state.
+ e.mu.Lock()
+ state := e.state
+ e.cc = v
+ e.mu.Unlock()
+ switch state {
+ case StateEstablished:
+ e.workMu.Lock()
+ e.mu.Lock()
+ if e.state == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ }
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
default:
return nil
}
@@ -929,6 +1110,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.lastErrorMu.Unlock()
return err
+ case *tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or returns the
+ // actual current MSS. Netstack just returns the defaultMSS
+ // always for now.
+ *o = header.TCPDefaultMSS
+ return nil
+
case *tcpip.SendBufferSizeOption:
e.sndBufMu.Lock()
*o = tcpip.SendBufferSizeOption(e.sndBufSize)
@@ -1067,6 +1256,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.CongestionControlOption:
+ e.mu.Lock()
+ *o = e.cc
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1098,6 +1293,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Addr == "" && addr.Port == 0 {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
return e.connect(addr, true, true)
}
@@ -1582,6 +1782,13 @@ func (e *endpoint) readyToRead(s *segment) {
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
+ // Check if the receive window is now closed. If so make sure
+ // we set the zero window before we deliver the segment to ensure
+ // that a subsequent read of the segment will correctly trigger
+ // a non-zero notification.
+ if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ e.zeroWindow = true
+ }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -1591,21 +1798,26 @@ func (e *endpoint) readyToRead(s *segment) {
e.waiterQueue.Notify(waiter.EventIn)
}
-// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
-func (e *endpoint) receiveBufferAvailable() int {
- e.rcvListMu.Lock()
- size := e.rcvBufSize
- used := e.rcvBufUsed
- e.rcvListMu.Unlock()
-
+// receiveBufferAvailableLocked calculates how many bytes are still available
+// in the receive buffer.
+// rcvListMu must be held when this function is called.
+func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if used >= size {
+ if e.rcvBufUsed >= e.rcvBufSize {
return 0
}
- return size - used
+ return e.rcvBufSize - e.rcvBufUsed
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ available := e.receiveBufferAvailableLocked()
+ e.rcvListMu.Unlock()
+ return available
}
func (e *endpoint) receiveBufferSize() int {
@@ -1616,6 +1828,33 @@ func (e *endpoint) receiveBufferSize() int {
return size
}
+func (e *endpoint) maxReceiveBufferSize() int {
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ // As a fallback return the hardcoded max buffer size.
+ return MaxBufferSize
+ }
+ return rs.Max
+}
+
+// rcvWndScaleForHandshake computes the receive window scale to offer to the
+// peer when window scaling is enabled (true by default). If auto-tuning is
+// disabled then the window scaling factor is based on the size of the
+// receiveBuffer otherwise we use the max permissible receive buffer size to
+// compute the scale.
+func (e *endpoint) rcvWndScaleForHandshake() int {
+ bufSizeForScale := e.receiveBufferSize()
+
+ e.rcvListMu.Lock()
+ autoTuningDisabled := e.rcvAutoParams.disabled
+ e.rcvListMu.Unlock()
+ if autoTuningDisabled {
+ return FindWndScale(seqnum.Size(bufSizeForScale))
+ }
+
+ return FindWndScale(seqnum.Size(e.maxReceiveBufferSize()))
+}
+
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
@@ -1708,6 +1947,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.RcvBufSize = e.rcvBufSize
s.RcvBufUsed = e.rcvBufUsed
s.RcvClosed = e.rcvClosed
+ s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
+ s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
+ s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
+ s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
+ s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
+ s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
+ s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
e.rcvListMu.Unlock()
// Endpoint TCP Option state.
@@ -1761,13 +2007,13 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RTTMeasureTime: e.snd.rttMeasureTime,
Closed: e.snd.closed,
RTO: e.snd.rto,
- SRTTInited: e.snd.srttInited,
MaxPayloadSize: e.snd.maxPayloadSize,
SndWndScale: e.snd.sndWndScale,
MaxSentAck: e.snd.maxSentAck,
}
e.snd.rtt.Lock()
s.Sender.SRTT = e.snd.rtt.srtt
+ s.Sender.SRTTInited = e.snd.rtt.srttInited
e.snd.rtt.Unlock()
if cubic, ok := e.snd.cc.(*cubicState); ok {
@@ -1815,3 +2061,7 @@ func (e *endpoint) State() uint32 {
defer e.mu.Unlock()
return uint32(e.state)
}
+
+func mssForRoute(r *stack.Route) uint16 {
+ return uint16(r.MTU() - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 5f30c2374..b93959034 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -19,9 +19,9 @@ import (
"sync"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
func (e *endpoint) drainSegmentLocked() {
@@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error {
tcpip.ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled,
tcpip.ErrNotPermitted,
+ tcpip.ErrAddressFamilyNotSupported,
}
messageToError = make(map[string]*tcpip.Error)
@@ -360,3 +361,23 @@ func loadError(s string) *tcpip.Error {
return e
}
+
+// saveMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
+ return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
+}
+
+// loadMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.measureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
+ return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
+ r.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index c30b45c2c..63666f0b3 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -17,12 +17,12 @@ package tcp
import (
"sync"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// Forwarder is a connection request forwarder, which allows clients to decide
@@ -47,7 +47,7 @@ type Forwarder struct {
// If rcvWnd is set to zero, the default buffer size is used instead.
func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
if rcvWnd == 0 {
- rcvWnd = DefaultBufferSize
+ rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
maxInFlight: maxInFlight,
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b31bcccfa..ee04dcfcc 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -24,13 +24,13 @@ import (
"strings"
"sync"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -41,13 +41,18 @@ const (
ProtocolNumber = header.TCPProtocolNumber
// MinBufferSize is the smallest size of a receive or send buffer.
- minBufferSize = 4 << 10 // 4096 bytes.
+ MinBufferSize = 4 << 10 // 4096 bytes.
- // DefaultBufferSize is the default size of the receive and send buffers.
- DefaultBufferSize = 1 << 20 // 1MB
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 1 << 20 // 1MB
- // MaxBufferSize is the largest size a receive and send buffer can grow to.
- maxBufferSize = 4 << 20 // 4MB
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 1 << 20 // 1MB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MB
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
@@ -79,13 +84,6 @@ const (
ccCubic = "cubic"
)
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
type protocol struct {
mu sync.Mutex
sackEnabled bool
@@ -93,7 +91,7 @@ type protocol struct {
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
- allowedCongestionControl []string
+ moderateReceiveBuffer bool
}
// Number returns the tcp protocol number.
@@ -188,7 +186,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
- case CongestionControlOption:
+ case tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
@@ -197,7 +195,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
}
- return tcpip.ErrInvalidOptionValue
+ // linux returns ENOENT when an invalid congestion control
+ // is specified.
+ return tcpip.ErrNoSuchFile
+
+ case tcpip.ModerateReceiveBufferOption:
+ p.mu.Lock()
+ p.moderateReceiveBuffer = bool(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -223,16 +230,25 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
*v = p.recvBufferSize
p.mu.Unlock()
return nil
- case *CongestionControlOption:
+
+ case *tcpip.CongestionControlOption:
p.mu.Lock()
- *v = CongestionControlOption(p.congestionControl)
+ *v = tcpip.CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
- case *AvailableCongestionControlOption:
+
+ case *tcpip.AvailableCongestionControlOption:
p.mu.Lock()
- *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
+
+ case *tcpip.ModerateReceiveBufferOption:
+ p.mu.Lock()
+ *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -241,8 +257,8 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{
- sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index f02fa6105..e90f9a7d9 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -16,9 +16,10 @@ package tcp
import (
"container/heap"
+ "time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// receiver holds the state necessary to receive TCP segments and turn them
@@ -38,6 +39,9 @@ type receiver struct {
// shrinking it.
rcvAcc seqnum.Value
+ // rcvWnd is the non-scaled receive window last advertised to the peer.
+ rcvWnd seqnum.Size
+
rcvWndScale uint8
closed bool
@@ -47,13 +51,14 @@ type receiver struct {
pendingBufSize seqnum.Size
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: rcvWnd,
+ pendingBufSize: pendingBufSize,
}
}
@@ -72,14 +77,16 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the current buffer size.
- n := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(n))
+ // Calculate the window size based on the available buffer space.
+ receiveBufferAvailable := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
if r.rcvAcc.LessThan(acc) {
r.rcvAcc = acc
}
-
- return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
+ // Stash away the non-scaled receive window as we use it for measuring
+ // receiver's estimated RTT.
+ r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
@@ -130,6 +137,21 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Update the segment that we're expecting to consume.
r.rcvNxt = segSeq.Add(segLen)
+ // In cases of a misbehaving sender which could send more than the
+ // advertised window, we could end up in a situation where we get a
+ // segment that exceeds the window advertised. Instead of partially
+ // accepting the segment and discarding bytes beyond the advertised
+ // window, we accept the whole segment and make sure r.rcvAcc is moved
+ // forward to match r.rcvNxt to indicate that the window is now closed.
+ //
+ // In absence of this check the r.acceptable() check fails and accepts
+ // segments that should be dropped because rcvWnd is calculated as
+ // the size of the interval (rcvNxt, rcvAcc] which becomes extremely
+ // large if rcvAcc is ever less than rcvNxt.
+ if r.rcvAcc.LessThan(r.rcvNxt) {
+ r.rcvAcc = r.rcvNxt
+ }
+
// Trim SACK Blocks to remove any SACK information that covers
// sequence numbers that have been consumed.
TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
@@ -198,6 +220,39 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
return true
}
+// updateRTT updates the receiver RTT measurement based on the sequence number
+// of the received segment.
+func (r *receiver) updateRTT() {
+ // From: https://public.lanl.gov/radiant/pubs/drs/sc2001-poster.pdf
+ //
+ // A system that is only transmitting acknowledgements can still
+ // estimate the round-trip time by observing the time between when a byte
+ // is first acknowledged and the receipt of data that is at least one
+ // window beyond the sequence number that was acknowledged.
+ r.ep.rcvListMu.Lock()
+ if r.ep.rcvAutoParams.rttMeasureTime.IsZero() {
+ // New measurement.
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) {
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime)
+ // We only store the minimum observed RTT here as this is only used in
+ // absence of a SRTT available from either timestamps or a sender
+ // measurement of RTT.
+ if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt {
+ r.ep.rcvAutoParams.rtt = rtt
+ }
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+}
+
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) {
@@ -226,10 +281,9 @@ func (r *receiver) handleRcvdSegment(s *segment) {
r.pendingBufUsed += s.logicalLen()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
}
- UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
-
// Immediately send an ack so that the peer knows it may
// have to retransmit.
r.ep.snd.sendAck()
@@ -237,6 +291,12 @@ func (r *receiver) handleRcvdSegment(s *segment) {
return
}
+ // Since we consumed a segment update the receiver's RTT estimate
+ // if required.
+ if segLen > 0 {
+ r.updateRTT()
+ }
+
// By consuming the current segment, we may have filled a gap in the
// sequence number domain that allows pending segments to be consumed
// now. So try to do it.
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
index 6a013d99b..7be86d68e 100644
--- a/pkg/tcpip/transport/tcp/sack.go
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -15,8 +15,8 @@
package tcp
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
const (
@@ -31,6 +31,13 @@ const (
// segment identified by segStart->segEnd.
func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
newSB := header.SACKBlock{Start: segStart, End: segEnd}
+
+ // Ignore any invalid SACK blocks or blocks that are before rcvNxt as
+ // those bytes have already been acked.
+ if newSB.End.LessThanEq(newSB.Start) || newSB.End.LessThan(rcvNxt) {
+ return
+ }
+
if sack.NumBlocks == 0 {
sack.Blocks[0] = newSB
sack.NumBlocks = 1
@@ -39,9 +46,8 @@ func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value
var n = 0
for i := 0; i < sack.NumBlocks; i++ {
start, end := sack.Blocks[i].Start, sack.Blocks[i].End
- if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
- // Discard any invalid blocks where end is before start
- // and discard any sack blocks that are before rcvNxt as
+ if end.LessThanEq(rcvNxt) {
+ // Discard any sack blocks that are before rcvNxt as
// those have already been acked.
continue
}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 1c5766a42..7ef2df377 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -19,8 +19,8 @@ import (
"strings"
"github.com/google/btree"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
const (
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
index b59eedc9d..b4e5ba0df 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
@@ -17,9 +17,9 @@ package tcp_test
import (
"testing"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
const smss = 1500
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 450d9fbc1..ea725d513 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,10 +18,10 @@ import (
"sync/atomic"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// segment represents a TCP segment. It holds the payload and parsed TCP segment
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index dd7e14aa6..7dc2741a6 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -17,7 +17,7 @@ package tcp
import (
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// saveData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b236d7af2..0fee7ab72 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -20,11 +20,11 @@ import (
"sync/atomic"
"time"
- "gvisor.googlesource.com/gvisor/pkg/sleep"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
const (
@@ -121,9 +121,8 @@ type sender struct {
// rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
// "round-trip time variation" and "retransmit timeout", as defined in
// section 2 of RFC 6298.
- rtt rtt
- rto time.Duration
- srttInited bool
+ rtt rtt
+ rto time.Duration
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
@@ -150,8 +149,9 @@ type sender struct {
type rtt struct {
sync.Mutex `state:"nosave"`
- srtt time.Duration
- rttvar time.Duration
+ srtt time.Duration
+ rttvar time.Duration
+ srttInited bool
}
// fastRecovery holds information related to fast recovery from a packet loss.
@@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{
ep: ep,
- sndCwnd: InitialCwnd,
- sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
@@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
return s
}
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
@@ -319,10 +323,10 @@ func (s *sender) sendAck() {
// available. This is done in accordance with section 2 of RFC 6298.
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
- if !s.srttInited {
+ if !s.rtt.srttInited {
s.rtt.rttvar = rtt / 2
s.rtt.srtt = rtt
- s.srttInited = true
+ s.rtt.srttInited = true
} else {
diff := s.rtt.srtt - rtt
if diff < 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 4d1519860..272bbcdbd 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -26,11 +26,11 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
func TestFastRecovery(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 025d133be..4e7f1a740 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -21,13 +21,13 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
// createConnectedWithSACKPermittedOption creates and connects c.ep with the
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 779ca8b76..915a98047 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,20 +21,20 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -1110,8 +1110,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
t.Fatalf("Listen failed: %v", err)
}
- // Do 3-way handshake.
- c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
+ // should not carry the window scaling option.
+ c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1600,7 +1601,6 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1614,7 +1614,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
}
// Do 3-way handshake.
- c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1713,7 +1713,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
// Do 3-way handshake.
- c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
// Wait for connection to be available.
select {
@@ -2767,11 +2767,11 @@ func TestDefaultBufferSizes(t *testing.T) {
}
}()
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2781,11 +2781,11 @@ func TestDefaultBufferSizes(t *testing.T) {
t.Fatalf("NewEndpoint failed; %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2795,8 +2795,8 @@ func TestDefaultBufferSizes(t *testing.T) {
t.Fatalf("NewEndpoint failed; %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
}
func TestMinMaxBufferSizes(t *testing.T) {
@@ -2810,11 +2810,11 @@ func TestMinMaxBufferSizes(t *testing.T) {
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2832,17 +2832,17 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil {
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil {
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
func makeStack() (*stack.Stack, *tcpip.Error) {
@@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) {
}
}
-func TestSetCongestionControl(t *testing.T) {
+func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
- cc tcp.CongestionControlOption
- mustPass bool
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
}{
- {"reno", true},
- {"cubic", true},
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
}
for _, tc := range testCases {
@@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) {
s := c.Stack()
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
}
- var cc tcp.CongestionControlOption
+ var cc tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tc.cc; got != want {
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
t.Fatalf("got congestion control: %v, want: %v", got, want)
}
})
}
}
-func TestAvailableCongestionControl(t *testing.T) {
+func TestStackAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcp.AvailableCongestionControlOption
+ var aCC tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
}
}
-func TestSetAvailableCongestionControl(t *testing.T) {
+func TestStackSetAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcp.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.AvailableCongestionControlOption
+ var cc tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
}
}
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
@@ -3902,3 +3976,273 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
+
+// This test verifies that the auto tuning does not grow the receive buffer if
+// the application is not reading the data actively.
+func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker which can cause packets to be dropped because
+ // the segment queue holding unprocessed packets is limited to 500.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size defined above.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ // NOTE: The timestamp values in the sent packets are meaningless to the
+ // peer so we just increment the timestamp value by 1 every batch as we
+ // are not really using them for anything. Send a single byte to verify
+ // the advertised window.
+ tsVal := rawEP.TSVal + 1
+
+ // Introduce a 25ms latency by delaying the first byte.
+ latency := 25 * time.Millisecond
+ time.Sleep(latency)
+ rawEP.SendPacketWithTS([]byte{1}, tsVal)
+
+ // Verify that the ACK has the expected window.
+ wantRcvWnd := receiveBufferSize
+ wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
+ rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ time.Sleep(25 * time.Millisecond)
+
+ // Allocate a large enough payload for the test.
+ b := make([]byte, int(receiveBufferSize)*2)
+ offset := 0
+ payloadSize := receiveBufferSize - 1
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetsSent++
+ }
+ // Resume the worker so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Since we read no bytes the window should goto zero till the
+ // application reads some of the data.
+ // Discard all intermediate acks except the last one.
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
+ }
+ }
+ rawEP.VerifyACKRcvWnd(0)
+
+ time.Sleep(25 * time.Millisecond)
+ // Verify that sending more data when window is closed is dropped and
+ // not acked.
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+
+ // Verify that the stack sends us back an ACK with the sequence number
+ // of the last packet sent indicating it was dropped.
+ p := c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.Window(0),
+ ))
+
+ // Now read all the data from the endpoint and verify that advertised
+ // window increases to the full available buffer size.
+ for {
+ _, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ }
+
+ // Verify that we receive a non-zero window update ACK. When running
+ // under thread santizer this test can end up sending more than 1
+ // ack, 1 for the non-zero window
+ p = c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
+ t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd)
+ }
+ },
+ ))
+}
+
+// This test verifies that the auto tuning does not grow the receive buffer if
+// the application is not reading the data actively.
+func TestReceiveBufferAutoTuning(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Enable Auto-tuning.
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker which can cause packets to be dropped because
+ // the segment queue holding unprocessed packets is limited to 500.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size used by stack.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ wantRcvWnd := receiveBufferSize
+ scaleRcvWnd := func(rcvWnd int) uint16 {
+ return uint16(rcvWnd >> uint16(c.WindowScale))
+ }
+ // Allocate a large array to send to the endpoint.
+ b := make([]byte, receiveBufferSize*48)
+
+ // In every iteration we will send double the number of bytes sent in
+ // the previous iteration and read the same from the app. The received
+ // window should grow by at least 2x of bytes read by the app in every
+ // RTT.
+ offset := 0
+ payloadSize := receiveBufferSize / 8
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal := rawEP.TSVal
+ // We are going to do our own computation of what the moderated receive
+ // buffer should be based on sent/copied data per RTT and verify that
+ // the advertised window by the stack matches our calculations.
+ prevCopied := 0
+ done := false
+ latency := 1 * time.Millisecond
+ for i := 0; !done; i++ {
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ totalSent := 0
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ totalSent += mss
+ packetsSent++
+ }
+ // Resume it so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Give 1ms for the worker to process the packets.
+ time.Sleep(1 * time.Millisecond)
+
+ // Verify that the advertised window on the ACK is reduced by
+ // the total bytes sent.
+ expectedWnd := wantRcvWnd - totalSent
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
+ }
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
+
+ // Now read all the data from the endpoint and invoke the
+ // moderation API to allow for receive buffer auto-tuning
+ // to happen before we measure the new window.
+ totalCopied := 0
+ for {
+ b, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ totalCopied += len(b)
+ }
+
+ // Invoke the moderation API. This is required for auto-tuning
+ // to happen. This method is normally expected to be invoked
+ // from a higher layer than tcpip.Endpoint. So we simulate
+ // copying to user-space by invoking it explicitly here.
+ c.EP.ModerateRecvBuf(totalCopied)
+
+ // Now send a keep-alive packet to trigger an ACK so that we can
+ // measure the new window.
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+
+ if i == 0 {
+ // In the first iteration the receiver based RTT is not
+ // yet known as a result the moderation code should not
+ // increase the advertised window.
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
+ prevCopied = totalCopied
+ } else {
+ rttCopied := totalCopied
+ if i == 1 {
+ // The moderation code accumulates copied bytes till
+ // RTT is established. So add in the bytes sent in
+ // the first iteration to the total bytes for this
+ // RTT.
+ rttCopied += prevCopied
+ // Now reset it to the initial value used by the
+ // auto tuning logic.
+ prevCopied = tcp.InitialCwnd * mss * 2
+ }
+ newWnd := rttCopied<<1 + 16*mss
+ grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
+ newWnd += (grow << 1)
+ if newWnd > maxReceiveBufferSize {
+ newWnd = maxReceiveBufferSize
+ done = true
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
+ wantRcvWnd = newWnd
+ prevCopied = rttCopied
+ // Increase the latency after first two iterations to
+ // establish a low RTT value in the receiver since it
+ // only tracks the lowest value. This ensures that when
+ // ModerateRcvBuf is called the elapsed time is always >
+ // rtt. Without this the test is flaky due to delays due
+ // to scheduling/wakeup etc.
+ latency += 50 * time.Millisecond
+ }
+ time.Sleep(latency)
+ offset += payloadSize
+ payloadSize *= 2
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 039bbcfba..a641e953d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -20,13 +20,13 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// createConnectedWithTimestampOption creates and connects c.ep with the
@@ -182,7 +182,7 @@ func TestTimeStampEnabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -239,7 +239,7 @@ func TestTimeStampDisabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index 1584e4095..19b0d31c5 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -6,7 +6,7 @@ go_library(
name = "context",
testonly = 1,
srcs = ["context.go"],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context",
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
"//:sandbox",
],
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 69a43b6f4..630dd7925 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -21,18 +21,18 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
@@ -72,12 +72,6 @@ const (
testInitialSequenceNumber = 789
)
-// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize
-// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is
-// used in tcp.newHandshake to determine the window scale to use when sending a
-// SYN/SYN-ACK.
-var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize)
-
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -134,6 +128,10 @@ type Context struct {
// TimeStampEnabled is true if ep is connected with the timestamp option
// enabled.
TimeStampEnabled bool
+
+ // WindowScale is the expected window scale in SYN packets sent by
+ // the stack.
+ WindowScale uint8
}
// New allocates and initializes a test context containing a new
@@ -142,11 +140,11 @@ func New(t *testing.T, mtu uint32) *Context {
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -184,9 +182,10 @@ func New(t *testing.T, mtu uint32) *Context {
})
return &Context{
- t: t,
- s: s,
- linkEP: linkEP,
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -520,35 +519,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window(rcvWnd) and any options if
+// specified.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if err != tcpip.ErrConnectStarted {
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -590,8 +575,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -604,6 +588,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.Port = tcpHdr.SourcePort()
}
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if epRcvBuf != nil {
+ if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+ }
+ c.Connect(iss, rcvWnd, options)
+}
+
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
// sending data and ACK packets easy while being able to manipulate the sequence
// numbers and timestamp values as needed.
@@ -666,6 +671,21 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
r.RecentTS = opts.TSVal
}
+// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
+// matches the provided rcvWnd.
+func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.Window(rcvWnd),
+ ),
+ )
+}
+
// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
func (r *RawEndpoint) VerifyACKNoSACK() {
r.VerifyACKHasSACK(nil)
@@ -726,7 +746,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
checker.TCPSynOptions(header.TCPSynOptions{
MSS: mss,
TS: true,
- WS: defaultWindowScale,
+ WS: int(c.WindowScale),
SACKPermitted: c.SACKEnabled(),
}),
),
@@ -741,6 +761,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build options w/ tsVal to be sent in the SYN-ACK.
synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
+ if wantOptions.WS != -1 {
+ offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
+ }
if wantOptions.TS {
offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index fc1c7cbd2..c70525f27 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -17,7 +17,7 @@ package tcp
import (
"time"
- "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sleep"
)
type timerState int
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 31a845dee..4bec48c0f 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -5,7 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "tcpconntrack",
srcs = ["tcp_conntrack.go"],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack",
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index f1dcd36d5..93712cd45 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -18,8 +18,8 @@
package tcpconntrack
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
// Result is returned when the state of a TCB is updated in response to an
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
index 435e136de..5e271b7ca 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -17,8 +17,8 @@ package tcpconntrack_test
import (
"testing"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
// connected creates a connection tracker TCB and sets it to a connected state
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index b9520d6e0..6dac66b50 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -24,8 +24,8 @@ go_library(
"protocol.go",
"udp_packet_list.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp",
- imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index fa7278286..cb0ea83a6 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -18,11 +18,11 @@ import (
"math"
"sync"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
@@ -169,6 +169,9 @@ func (e *endpoint) Close() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -695,8 +698,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+func (e *endpoint) disconnect() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.state != stateConnected {
+ return nil
+ }
+ id := stack.TransportEndpointID{}
+ // Exclude ephemerally bound endpoints.
+ if e.bindNICID != 0 || e.id.LocalAddress == "" {
+ var err *tcpip.Error
+ id = stack.TransportEndpointID{
+ LocalPort: e.id.LocalPort,
+ LocalAddress: e.id.LocalAddress,
+ }
+ id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ return err
+ }
+ e.state = stateBound
+ } else {
+ e.state = stateInitial
+ }
+
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.id = id
+ e.route.Release()
+ e.route = stack.Route{}
+ e.dstPort = 0
+
+ return nil
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Addr == "" {
+ return e.disconnect()
+ }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -731,12 +770,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress,
+ LocalAddress: e.id.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
+ if e.state == stateInitial {
+ id.LocalAddress = r.LocalAddress
+ }
+
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 74e8e9fd5..18e786397 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,10 +15,10 @@
package udp
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// saveData saves udpPacket.data field.
@@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
-
- e.id.LocalAddress = e.route.LocalAddress
} else if len(e.id.LocalAddress) != 0 { // stateBound
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 25bdd2929..a874fc9fd 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -15,10 +15,10 @@
package udp
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// Forwarder is a session request forwarder, which allows clients to decide
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 3d31dfbf1..f76e7fbe1 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -21,12 +21,12 @@
package udp
import (
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 86a8fa19b..75129a2ff 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -21,17 +21,17 @@ import (
"testing"
"time"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (