summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2019-06-12 13:34:47 -0700
committerShentubot <shentubot@google.com>2019-06-12 13:35:50 -0700
commit70578806e8d3e01fae2249b3e602cd5b05d378a0 (patch)
tree6909e9f6103e23c43d753d10d9ee424670e8c8cd
parentbb849bad296f372670c2d2cf97424f74cf750ce2 (diff)
Add support for TCP_CONGESTION socket option.
This CL also cleans up the error returned for setting congestion control which was incorrectly returning EINVAL instead of ENOENT. PiperOrigin-RevId: 252889093
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go30
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go3
-rw-r--r--pkg/tcpip/transport/tcp/cubic_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go45
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go22
-rw-r--r--pkg/tcpip/transport/tcp/snd.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go112
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc104
-rw-r--r--test/syscalls/linux/tcp_socket.cc127
12 files changed, 481 insertions, 62 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index f67451179..a50798cb3 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -920,6 +920,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.TCP_CONGESTION:
+ if outLen <= 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.CongestionControlOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // We match linux behaviour here where it returns the lower of
+ // TCP_CA_NAME_MAX bytes or the value of the option length.
+ //
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const tcpCANameMax = 16
+
+ toCopy := tcpCANameMax
+ if outLen < tcpCANameMax {
+ toCopy = outLen
+ }
+ b := make([]byte, toCopy)
+ copy(b, v)
+ return b, nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1222,6 +1246,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_CONGESTION:
+ v := tcpip.CongestionControlOption(optVal)
+ if err := ep.SetSockOpt(v); err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ return nil
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 85ef014d0..04c776205 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -472,6 +472,14 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 9db38196b..a9dbfb930 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -21,6 +21,7 @@ go_library(
"accept.go",
"connect.go",
"cubic.go",
+ "cubic_state.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index e618cd2b9..7b1f5e763 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@ import (
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
type cubicState struct {
// wLastMax is the previous wMax value.
wLastMax float64
@@ -33,7 +34,7 @@ type cubicState struct {
// t denotes the time when the current congestion avoidance
// was entered.
- t time.Time
+ t time.Time `state:".(unixTime)"`
// numCongestionEvents tracks the number of congestion events since last
// RTO.
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23422ca5e..1efe9d3fb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -286,7 +287,7 @@ type endpoint struct {
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
- cc CongestionControlOption
+ cc tcpip.CongestionControlOption
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
@@ -394,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.rcvBufSize = rs.Default
}
- var cs CongestionControlOption
+ var cs tcpip.CongestionControlOption
if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
@@ -898,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ // Acquire the work mutex as we may need to
+ // reinitialize the congestion control state.
+ e.mu.Lock()
+ state := e.state
+ e.cc = v
+ e.mu.Unlock()
+ switch state {
+ case StateEstablished:
+ e.workMu.Lock()
+ e.mu.Lock()
+ if e.state == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ }
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
default:
return nil
}
@@ -1067,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.CongestionControlOption:
+ e.mu.Lock()
+ *o = e.cc
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b31bcccfa..59f4009a1 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -79,13 +79,6 @@ const (
ccCubic = "cubic"
)
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
type protocol struct {
mu sync.Mutex
sackEnabled bool
@@ -93,7 +86,6 @@ type protocol struct {
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
- allowedCongestionControl []string
}
// Number returns the tcp protocol number.
@@ -188,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
- case CongestionControlOption:
+ case tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
@@ -197,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
}
- return tcpip.ErrInvalidOptionValue
+ // linux returns ENOENT when an invalid congestion control
+ // is specified.
+ return tcpip.ErrNoSuchFile
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -223,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
*v = p.recvBufferSize
p.mu.Unlock()
return nil
- case *CongestionControlOption:
+ case *tcpip.CongestionControlOption:
p.mu.Lock()
- *v = CongestionControlOption(p.congestionControl)
+ *v = tcpip.CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
- case *AvailableCongestionControlOption:
+ case *tcpip.AvailableCongestionControlOption:
p.mu.Lock()
- *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b236d7af2..daa3e8341 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{
ep: ep,
- sndCwnd: InitialCwnd,
- sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
@@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
return s
}
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 779ca8b76..7d8987219 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) {
}
}
-func TestSetCongestionControl(t *testing.T) {
+func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
- cc tcp.CongestionControlOption
- mustPass bool
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
}{
- {"reno", true},
- {"cubic", true},
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
}
for _, tc := range testCases {
@@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) {
s := c.Stack()
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
}
- var cc tcp.CongestionControlOption
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tc.cc; got != want {
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
t.Fatalf("got congestion control: %v, want: %v", got, want)
}
})
}
}
-func TestAvailableCongestionControl(t *testing.T) {
+func TestStackAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcp.AvailableCongestionControlOption
+ var aCC tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
}
}
-func TestSetAvailableCongestionControl(t *testing.T) {
+func TestStackSetAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcp.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.AvailableCongestionControlOption
+ var cc tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
}
}
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 69a43b6f4..a4d89e24d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -520,35 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window(rcvWnd) and any options if
+// specified.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if err != tcpip.ErrConnectStarted {
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -590,8 +576,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -604,6 +589,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.Port = tcpHdr.SourcePort()
}
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if epRcvBuf != nil {
+ if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+ }
+ c.Connect(iss, rcvWnd, options)
+}
+
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
// sending data and ACK packets easy while being able to manipulate the sequence
// numbers and timestamp values as needed.
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 5b198f49d..0b76280a7 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -592,5 +592,109 @@ TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) {
EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data)));
}
+TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ // Netstack only supports reno & cubic so we only test these two values here.
+ {
+ const char kSetCC[kTcpCaNameMax] = "reno";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+ {
+ const char kSetCC[kTcpCaNameMax] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ {
+ // Verify that getsockopt/setsockopt work with buffers smaller than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[sizeof(kSetCC)];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ {
+ // Verify that getsockopt works with buffers larger than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax + 5];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // Linux copies the minimum of kTcpCaNameMax or the length of the passed in
+ // buffer and sets optlen to the number of bytes actually copied
+ // irrespective of the actual length of the congestion control name.
+ EXPECT_EQ(kTcpCaNameMax, optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char old_cc[kTcpCaNameMax];
+ socklen_t optlen;
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &old_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ const char kSetCC[] = "invalid_ca_cc";
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &kSetCC, strlen(kSetCC)),
+ SyscallFailsWithErrno(ENOENT));
+
+ char got_cc[kTcpCaNameMax];
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION,
+ &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc)));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index e3f9f9f9d..e95b644ac 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -751,6 +751,133 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) {
EXPECT_THAT(close(s.release()), SyscallSucceeds());
}
+// Test that setting a supported congestion control algorithm succeeds for an
+// unconnected TCP socket
+TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ const char kSetCC[kTcpCaNameMax] = "reno";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the linux kernel sets optlen to the lower of the
+ // size of the buffer passed in or kTcpCaNameMax and not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax)));
+ }
+ {
+ const char kSetCC[kTcpCaNameMax] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax];
+ memset(got_cc, '1', sizeof(got_cc));
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the linux kernel sets optlen to the lower of the
+ // size of the buffer passed in or kTcpCaNameMax and not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax)));
+ }
+}
+
+// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is
+// consistent between linux and gvisor when the passed in buffer is smaller than
+// kTcpCaNameMax.
+TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ // Verify that getsockopt/setsockopt work with buffers smaller than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[sizeof(kSetCC)];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(sizeof(got_cc), optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc)));
+ }
+}
+
+// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is
+// consistent between linux and gvisor when the passed in buffer is larger than
+// kTcpCaNameMax.
+TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ {
+ // Verify that getsockopt works with buffers larger than
+ // kTcpCaNameMax.
+ const char kSetCC[] = "cubic";
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC,
+ strlen(kSetCC)),
+ SyscallSucceedsWithValue(0));
+
+ char got_cc[kTcpCaNameMax + 5];
+ socklen_t optlen = sizeof(got_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // Linux copies the minimum of kTcpCaNameMax or the length of the passed in
+ // buffer and sets optlen to the number of bytes actually copied
+ // irrespective of the actual length of the congestion control name.
+ EXPECT_EQ(kTcpCaNameMax, optlen);
+ EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC)));
+ }
+}
+
+// Test that setting an unsupported congestion control algorithm fails for an
+// unconnected TCP socket.
+TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) {
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const int kTcpCaNameMax = 16;
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char old_cc[kTcpCaNameMax];
+ socklen_t optlen = sizeof(old_cc);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+
+ const char kSetCC[] = "invalid_ca_kSetCC";
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)),
+ SyscallFailsWithErrno(ENOENT));
+
+ char got_cc[kTcpCaNameMax];
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen),
+ SyscallSucceedsWithValue(0));
+ // We ignore optlen here as the linux kernel sets optlen to the lower of the
+ // size of the buffer passed in or kTcpCaNameMax and not the length of the
+ // congestion control algorithm's actual name.
+ EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax)));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));