diff options
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic_state.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 45 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 112 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 52 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_tcp_generic.cc | 104 | ||||
-rw-r--r-- | test/syscalls/linux/tcp_socket.cc | 127 |
12 files changed, 481 insertions, 62 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f67451179..a50798cb3 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -920,6 +920,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa t.Kernel().EmitUnimplementedEvent(t) + case linux.TCP_CONGESTION: + if outLen <= 0 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.CongestionControlOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // We match linux behaviour here where it returns the lower of + // TCP_CA_NAME_MAX bytes or the value of the option length. + // + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const tcpCANameMax = 16 + + toCopy := tcpCANameMax + if outLen < tcpCANameMax { + toCopy = outLen + } + b := make([]byte, toCopy) + copy(b, v) + return b, nil + default: emitUnimplementedEventTCP(t, name) } @@ -1222,6 +1246,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_CONGESTION: + v := tcpip.CongestionControlOption(optVal) + if err := ep.SetSockOpt(v); err != nil { + return syserr.TranslateNetstackError(err) + } + return nil case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 85ef014d0..04c776205 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -472,6 +472,14 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get +// the current congestion control algorithm. +type CongestionControlOption string + +// AvailableCongestionControlOption is used to query the supported congestion +// control algorithms. +type AvailableCongestionControlOption string + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 9db38196b..a9dbfb930 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -21,6 +21,7 @@ go_library( "accept.go", "connect.go", "cubic.go", + "cubic_state.go", "endpoint.go", "endpoint_state.go", "forwarder.go", diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index e618cd2b9..7b1f5e763 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -23,6 +23,7 @@ import ( // control algorithm state. // // See: https://tools.ietf.org/html/rfc8312. +// +stateify savable type cubicState struct { // wLastMax is the previous wMax value. wLastMax float64 @@ -33,7 +34,7 @@ type cubicState struct { // t denotes the time when the current congestion avoidance // was entered. - t time.Time + t time.Time `state:".(unixTime)"` // numCongestionEvents tracks the number of congestion events since last // RTO. diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go new file mode 100644 index 000000000..d0f58cfaf --- /dev/null +++ b/pkg/tcpip/transport/tcp/cubic_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveT is invoked by stateify. +func (c *cubicState) saveT() unixTime { + return unixTime{c.t.Unix(), c.t.UnixNano()} +} + +// loadT is invoked by stateify. +func (c *cubicState) loadT(unix unixTime) { + c.t = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23422ca5e..1efe9d3fb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "fmt" "math" + "strings" "sync" "sync/atomic" "time" @@ -286,7 +287,7 @@ type endpoint struct { // cc stores the name of the Congestion Control algorithm to use for // this endpoint. - cc CongestionControlOption + cc tcpip.CongestionControlOption // The following are used when a "packet too big" control packet is // received. They are protected by sndBufMu. They are used to @@ -394,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.rcvBufSize = rs.Default } - var cs CongestionControlOption + var cs tcpip.CongestionControlOption if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } @@ -898,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.CongestionControlOption: + // Query the available cc algorithms in the stack and + // validate that the specified algorithm is actually + // supported in the stack. + var avail tcpip.AvailableCongestionControlOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { + return err + } + availCC := strings.Split(string(avail), " ") + for _, cc := range availCC { + if v == tcpip.CongestionControlOption(cc) { + // Acquire the work mutex as we may need to + // reinitialize the congestion control state. + e.mu.Lock() + state := e.state + e.cc = v + e.mu.Unlock() + switch state { + case StateEstablished: + e.workMu.Lock() + e.mu.Lock() + if e.state == state { + e.snd.cc = e.snd.initCongestionControl(e.cc) + } + e.mu.Unlock() + e.workMu.Unlock() + } + return nil + } + } + + // Linux returns ENOENT when an invalid congestion + // control algorithm is specified. + return tcpip.ErrNoSuchFile default: return nil } @@ -1067,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.CongestionControlOption: + e.mu.Lock() + *o = e.cc + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b31bcccfa..59f4009a1 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -79,13 +79,6 @@ const ( ccCubic = "cubic" ) -// CongestionControlOption sets the current congestion control algorithm. -type CongestionControlOption string - -// AvailableCongestionControlOption returns the supported congestion control -// algorithms. -type AvailableCongestionControlOption string - type protocol struct { mu sync.Mutex sackEnabled bool @@ -93,7 +86,6 @@ type protocol struct { recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string - allowedCongestionControl []string } // Number returns the tcp protocol number. @@ -188,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case CongestionControlOption: + case tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { if string(v) == c { p.mu.Lock() @@ -197,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { return nil } } - return tcpip.ErrInvalidOptionValue + // linux returns ENOENT when an invalid congestion control + // is specified. + return tcpip.ErrNoSuchFile default: return tcpip.ErrUnknownProtocolOption } @@ -223,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { *v = p.recvBufferSize p.mu.Unlock() return nil - case *CongestionControlOption: + case *tcpip.CongestionControlOption: p.mu.Lock() - *v = CongestionControlOption(p.congestionControl) + *v = tcpip.CongestionControlOption(p.congestionControl) p.mu.Unlock() return nil - case *AvailableCongestionControlOption: + case *tcpip.AvailableCongestionControlOption: p.mu.Lock() - *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.Unlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b236d7af2..daa3e8341 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s := &sender{ ep: ep, - sndCwnd: InitialCwnd, - sndSsthresh: math.MaxInt64, sndWnd: sndWnd, sndUna: iss + 1, sndNxt: iss + 1, @@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint return s } -func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { +// initCongestionControl initializes the specified congestion control module and +// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to +// their initial values. +func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { + s.sndCwnd = InitialCwnd + s.sndSsthresh = math.MaxInt64 + switch congestionControlName { case ccCubic: return newCubicCC(s) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 779ca8b76..7d8987219 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) { } } -func TestSetCongestionControl(t *testing.T) { +func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { - cc tcp.CongestionControlOption - mustPass bool + cc tcpip.CongestionControlOption + err *tcpip.Error }{ - {"reno", true}, - {"cubic", true}, + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, } for _, tc := range testCases { @@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) { s := c.Stack() - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err) + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) } - var cc tcp.CongestionControlOption + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tc.cc; got != want { + + got, want := cc, oldCC + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. + if tc.err == nil { + want = tc.cc + } + if got != want { t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } } -func TestAvailableCongestionControl(t *testing.T) { +func TestStackAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcp.AvailableCongestionControlOption + var aCC tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) } } -func TestSetAvailableCongestionControl(t *testing.T) { +func TestStackSetAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcp.AvailableCongestionControlOption("xyz") + aCC := tcpip.AvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcp.AvailableCongestionControlOption + var cc tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + } +} + +func TestEndpointSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcpip.CongestionControlOption + err *tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, + } + + for _, connected := range []bool{false, true} { + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + var oldCC tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&oldCC); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + } + + if connected { + c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + } + + if err := c.EP.SetSockOpt(tc.cc); err != tc.err { + t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&cc); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + } + + got, want := cc, oldCC + // If SetSockOpt is expected to succeed then the + // returned value for congestion control should match + // the one specified in the SetSockOpt above, else it + // should be what it was before the call to SetSockOpt. + if tc.err == nil { + want = tc.cc + } + if got != want { + t.Fatalf("got congestion control: %v, want: %v", got, want) + } + }) + } } } func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() - opt := tcp.CongestionControlOption("cubic") + opt := tcpip.CongestionControlOption("cubic") if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 69a43b6f4..a4d89e24d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -520,35 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. +// Connect performs the 3-way handshake for c.EP with the provided Initial +// Sequence Number (iss) and receive window(rcvWnd) and any options if +// specified. // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } - +// +// PreCondition: c.EP must already be created. +func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if err != tcpip.ErrConnectStarted { + if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -590,8 +576,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -604,6 +589,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.Port = tcpHdr.SourcePort() } +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if epRcvBuf != nil { + if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + } + c.Connect(iss, rcvWnd, options) +} + // RawEndpoint is just a small wrapper around a TCP endpoint's state to make // sending data and ACK packets easy while being able to manipulate the sequence // numbers and timestamp values as needed. diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 5b198f49d..0b76280a7 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -592,5 +592,109 @@ TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) { EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data))); } +TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + // Netstack only supports reno & cubic so we only test these two values here. + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char old_cc[kTcpCaNameMax]; + socklen_t optlen; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_cc"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index e3f9f9f9d..e95b644ac 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -751,6 +751,133 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { EXPECT_THAT(close(s.release()), SyscallSucceeds()); } +// Test that setting a supported congestion control algorithm succeeds for an +// unconnected TCP socket +TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is smaller than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(sizeof(got_cc), optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is larger than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +// Test that setting an unsupported congestion control algorithm fails for an +// unconnected TCP socket. +TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char old_cc[kTcpCaNameMax]; + socklen_t optlen = sizeof(old_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_kSetCC"; + ASSERT_THAT( + setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax))); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); |