From 033f96cc9313d7ceb3df14227ef3724ec3295d2a Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 31 May 2019 16:16:24 -0700 Subject: Change segment queue limit to be of fixed size. Netstack sets the unprocessed segment queue size to match the receive buffer size. This is not required as this queue only needs to hold enough for a short duration before the endpoint goroutine can process it. Updates #230 PiperOrigin-RevId: 250976323 --- pkg/tcpip/transport/tcp/protocol.go | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'pkg/tcpip/transport/tcp/protocol.go') diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index d31a1edcb..b31bcccfa 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -48,6 +48,10 @@ const ( // MaxBufferSize is the largest size a receive and send buffer can grow to. maxBufferSize = 4 << 20 // 4MB + + // MaxUnprocessedSegments is the maximum number of unprocessed segments + // that can be queued for a given endpoint. + MaxUnprocessedSegments = 300 ) // SACKEnabled option can be used to enable SACK support in the TCP -- cgit v1.2.3 From 70578806e8d3e01fae2249b3e602cd5b05d378a0 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 12 Jun 2019 13:34:47 -0700 Subject: Add support for TCP_CONGESTION socket option. This CL also cleans up the error returned for setting congestion control which was incorrectly returning EINVAL instead of ENOENT. PiperOrigin-RevId: 252889093 --- pkg/sentry/socket/epsocket/epsocket.go | 30 +++++ pkg/tcpip/tcpip.go | 8 ++ pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/cubic.go | 3 +- pkg/tcpip/transport/tcp/cubic_state.go | 29 +++++ pkg/tcpip/transport/tcp/endpoint.go | 45 +++++++- pkg/tcpip/transport/tcp/protocol.go | 22 ++-- pkg/tcpip/transport/tcp/snd.go | 10 +- pkg/tcpip/transport/tcp/tcp_test.go | 112 +++++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 52 +++++---- test/syscalls/linux/socket_ip_tcp_generic.cc | 104 +++++++++++++++++ test/syscalls/linux/tcp_socket.cc | 127 +++++++++++++++++++++ 12 files changed, 481 insertions(+), 62 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/cubic_state.go (limited to 'pkg/tcpip/transport/tcp/protocol.go') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f67451179..a50798cb3 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -920,6 +920,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa t.Kernel().EmitUnimplementedEvent(t) + case linux.TCP_CONGESTION: + if outLen <= 0 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.CongestionControlOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // We match linux behaviour here where it returns the lower of + // TCP_CA_NAME_MAX bytes or the value of the option length. + // + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const tcpCANameMax = 16 + + toCopy := tcpCANameMax + if outLen < tcpCANameMax { + toCopy = outLen + } + b := make([]byte, toCopy) + copy(b, v) + return b, nil + default: emitUnimplementedEventTCP(t, name) } @@ -1222,6 +1246,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_CONGESTION: + v := tcpip.CongestionControlOption(optVal) + if err := ep.SetSockOpt(v); err != nil { + return syserr.TranslateNetstackError(err) + } + return nil case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 85ef014d0..04c776205 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -472,6 +472,14 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get +// the current congestion control algorithm. +type CongestionControlOption string + +// AvailableCongestionControlOption is used to query the supported congestion +// control algorithms. +type AvailableCongestionControlOption string + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 9db38196b..a9dbfb930 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -21,6 +21,7 @@ go_library( "accept.go", "connect.go", "cubic.go", + "cubic_state.go", "endpoint.go", "endpoint_state.go", "forwarder.go", diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index e618cd2b9..7b1f5e763 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -23,6 +23,7 @@ import ( // control algorithm state. // // See: https://tools.ietf.org/html/rfc8312. +// +stateify savable type cubicState struct { // wLastMax is the previous wMax value. wLastMax float64 @@ -33,7 +34,7 @@ type cubicState struct { // t denotes the time when the current congestion avoidance // was entered. - t time.Time + t time.Time `state:".(unixTime)"` // numCongestionEvents tracks the number of congestion events since last // RTO. diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go new file mode 100644 index 000000000..d0f58cfaf --- /dev/null +++ b/pkg/tcpip/transport/tcp/cubic_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveT is invoked by stateify. +func (c *cubicState) saveT() unixTime { + return unixTime{c.t.Unix(), c.t.UnixNano()} +} + +// loadT is invoked by stateify. +func (c *cubicState) loadT(unix unixTime) { + c.t = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23422ca5e..1efe9d3fb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "fmt" "math" + "strings" "sync" "sync/atomic" "time" @@ -286,7 +287,7 @@ type endpoint struct { // cc stores the name of the Congestion Control algorithm to use for // this endpoint. - cc CongestionControlOption + cc tcpip.CongestionControlOption // The following are used when a "packet too big" control packet is // received. They are protected by sndBufMu. They are used to @@ -394,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.rcvBufSize = rs.Default } - var cs CongestionControlOption + var cs tcpip.CongestionControlOption if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } @@ -898,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.CongestionControlOption: + // Query the available cc algorithms in the stack and + // validate that the specified algorithm is actually + // supported in the stack. + var avail tcpip.AvailableCongestionControlOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { + return err + } + availCC := strings.Split(string(avail), " ") + for _, cc := range availCC { + if v == tcpip.CongestionControlOption(cc) { + // Acquire the work mutex as we may need to + // reinitialize the congestion control state. + e.mu.Lock() + state := e.state + e.cc = v + e.mu.Unlock() + switch state { + case StateEstablished: + e.workMu.Lock() + e.mu.Lock() + if e.state == state { + e.snd.cc = e.snd.initCongestionControl(e.cc) + } + e.mu.Unlock() + e.workMu.Unlock() + } + return nil + } + } + + // Linux returns ENOENT when an invalid congestion + // control algorithm is specified. + return tcpip.ErrNoSuchFile default: return nil } @@ -1067,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.CongestionControlOption: + e.mu.Lock() + *o = e.cc + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b31bcccfa..59f4009a1 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -79,13 +79,6 @@ const ( ccCubic = "cubic" ) -// CongestionControlOption sets the current congestion control algorithm. -type CongestionControlOption string - -// AvailableCongestionControlOption returns the supported congestion control -// algorithms. -type AvailableCongestionControlOption string - type protocol struct { mu sync.Mutex sackEnabled bool @@ -93,7 +86,6 @@ type protocol struct { recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string - allowedCongestionControl []string } // Number returns the tcp protocol number. @@ -188,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case CongestionControlOption: + case tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { if string(v) == c { p.mu.Lock() @@ -197,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { return nil } } - return tcpip.ErrInvalidOptionValue + // linux returns ENOENT when an invalid congestion control + // is specified. + return tcpip.ErrNoSuchFile default: return tcpip.ErrUnknownProtocolOption } @@ -223,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { *v = p.recvBufferSize p.mu.Unlock() return nil - case *CongestionControlOption: + case *tcpip.CongestionControlOption: p.mu.Lock() - *v = CongestionControlOption(p.congestionControl) + *v = tcpip.CongestionControlOption(p.congestionControl) p.mu.Unlock() return nil - case *AvailableCongestionControlOption: + case *tcpip.AvailableCongestionControlOption: p.mu.Lock() - *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.Unlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b236d7af2..daa3e8341 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s := &sender{ ep: ep, - sndCwnd: InitialCwnd, - sndSsthresh: math.MaxInt64, sndWnd: sndWnd, sndUna: iss + 1, sndNxt: iss + 1, @@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint return s } -func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { +// initCongestionControl initializes the specified congestion control module and +// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to +// their initial values. +func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { + s.sndCwnd = InitialCwnd + s.sndSsthresh = math.MaxInt64 + switch congestionControlName { case ccCubic: return newCubicCC(s) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 779ca8b76..7d8987219 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) { } } -func TestSetCongestionControl(t *testing.T) { +func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { - cc tcp.CongestionControlOption - mustPass bool + cc tcpip.CongestionControlOption + err *tcpip.Error }{ - {"reno", true}, - {"cubic", true}, + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, } for _, tc := range testCases { @@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) { s := c.Stack() - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err) + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) } - var cc tcp.CongestionControlOption + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tc.cc; got != want { + + got, want := cc, oldCC + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. + if tc.err == nil { + want = tc.cc + } + if got != want { t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } } -func TestAvailableCongestionControl(t *testing.T) { +func TestStackAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcp.AvailableCongestionControlOption + var aCC tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) } } -func TestSetAvailableCongestionControl(t *testing.T) { +func TestStackSetAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcp.AvailableCongestionControlOption("xyz") + aCC := tcpip.AvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcp.AvailableCongestionControlOption + var cc tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + } +} + +func TestEndpointSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcpip.CongestionControlOption + err *tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, + } + + for _, connected := range []bool{false, true} { + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + var oldCC tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&oldCC); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + } + + if connected { + c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + } + + if err := c.EP.SetSockOpt(tc.cc); err != tc.err { + t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&cc); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + } + + got, want := cc, oldCC + // If SetSockOpt is expected to succeed then the + // returned value for congestion control should match + // the one specified in the SetSockOpt above, else it + // should be what it was before the call to SetSockOpt. + if tc.err == nil { + want = tc.cc + } + if got != want { + t.Fatalf("got congestion control: %v, want: %v", got, want) + } + }) + } } } func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() - opt := tcp.CongestionControlOption("cubic") + opt := tcpip.CongestionControlOption("cubic") if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 69a43b6f4..a4d89e24d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -520,35 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. +// Connect performs the 3-way handshake for c.EP with the provided Initial +// Sequence Number (iss) and receive window(rcvWnd) and any options if +// specified. // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } - +// +// PreCondition: c.EP must already be created. +func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if err != tcpip.ErrConnectStarted { + if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -590,8 +576,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -604,6 +589,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.Port = tcpHdr.SourcePort() } +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if epRcvBuf != nil { + if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + } + c.Connect(iss, rcvWnd, options) +} + // RawEndpoint is just a small wrapper around a TCP endpoint's state to make // sending data and ACK packets easy while being able to manipulate the sequence // numbers and timestamp values as needed. diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 5b198f49d..0b76280a7 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -592,5 +592,109 @@ TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) { EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data))); } +TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + // Netstack only supports reno & cubic so we only test these two values here. + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char old_cc[kTcpCaNameMax]; + socklen_t optlen; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_cc"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index e3f9f9f9d..e95b644ac 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -751,6 +751,133 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { EXPECT_THAT(close(s.release()), SyscallSucceeds()); } +// Test that setting a supported congestion control algorithm succeeds for an +// unconnected TCP socket +TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is smaller than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(sizeof(got_cc), optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is larger than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +// Test that setting an unsupported congestion control algorithm fails for an +// unconnected TCP socket. +TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char old_cc[kTcpCaNameMax]; + socklen_t optlen = sizeof(old_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_kSetCC"; + ASSERT_THAT( + setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax))); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3