summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2019-06-12 13:34:47 -0700
committerShentubot <shentubot@google.com>2019-06-12 13:35:50 -0700
commit70578806e8d3e01fae2249b3e602cd5b05d378a0 (patch)
tree6909e9f6103e23c43d753d10d9ee424670e8c8cd /pkg/tcpip
parentbb849bad296f372670c2d2cf97424f74cf750ce2 (diff)
Add support for TCP_CONGESTION socket option.
This CL also cleans up the error returned for setting congestion control which was incorrectly returning EINVAL instead of ENOENT. PiperOrigin-RevId: 252889093
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go3
-rw-r--r--pkg/tcpip/transport/tcp/cubic_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go45
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go22
-rw-r--r--pkg/tcpip/transport/tcp/snd.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go112
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
9 files changed, 220 insertions, 62 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 85ef014d0..04c776205 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -472,6 +472,14 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 9db38196b..a9dbfb930 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -21,6 +21,7 @@ go_library(
"accept.go",
"connect.go",
"cubic.go",
+ "cubic_state.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index e618cd2b9..7b1f5e763 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@ import (
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
type cubicState struct {
// wLastMax is the previous wMax value.
wLastMax float64
@@ -33,7 +34,7 @@ type cubicState struct {
// t denotes the time when the current congestion avoidance
// was entered.
- t time.Time
+ t time.Time `state:".(unixTime)"`
// numCongestionEvents tracks the number of congestion events since last
// RTO.
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23422ca5e..1efe9d3fb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -286,7 +287,7 @@ type endpoint struct {
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
- cc CongestionControlOption
+ cc tcpip.CongestionControlOption
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
@@ -394,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.rcvBufSize = rs.Default
}
- var cs CongestionControlOption
+ var cs tcpip.CongestionControlOption
if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
@@ -898,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ // Acquire the work mutex as we may need to
+ // reinitialize the congestion control state.
+ e.mu.Lock()
+ state := e.state
+ e.cc = v
+ e.mu.Unlock()
+ switch state {
+ case StateEstablished:
+ e.workMu.Lock()
+ e.mu.Lock()
+ if e.state == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ }
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
default:
return nil
}
@@ -1067,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.CongestionControlOption:
+ e.mu.Lock()
+ *o = e.cc
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b31bcccfa..59f4009a1 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -79,13 +79,6 @@ const (
ccCubic = "cubic"
)
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
type protocol struct {
mu sync.Mutex
sackEnabled bool
@@ -93,7 +86,6 @@ type protocol struct {
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
- allowedCongestionControl []string
}
// Number returns the tcp protocol number.
@@ -188,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
- case CongestionControlOption:
+ case tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
@@ -197,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
}
- return tcpip.ErrInvalidOptionValue
+ // linux returns ENOENT when an invalid congestion control
+ // is specified.
+ return tcpip.ErrNoSuchFile
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -223,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
*v = p.recvBufferSize
p.mu.Unlock()
return nil
- case *CongestionControlOption:
+ case *tcpip.CongestionControlOption:
p.mu.Lock()
- *v = CongestionControlOption(p.congestionControl)
+ *v = tcpip.CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
- case *AvailableCongestionControlOption:
+ case *tcpip.AvailableCongestionControlOption:
p.mu.Lock()
- *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b236d7af2..daa3e8341 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{
ep: ep,
- sndCwnd: InitialCwnd,
- sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
@@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
return s
}
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 779ca8b76..7d8987219 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) {
}
}
-func TestSetCongestionControl(t *testing.T) {
+func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
- cc tcp.CongestionControlOption
- mustPass bool
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
}{
- {"reno", true},
- {"cubic", true},
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
}
for _, tc := range testCases {
@@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) {
s := c.Stack()
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
}
- var cc tcp.CongestionControlOption
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tc.cc; got != want {
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
t.Fatalf("got congestion control: %v, want: %v", got, want)
}
})
}
}
-func TestAvailableCongestionControl(t *testing.T) {
+func TestStackAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcp.AvailableCongestionControlOption
+ var aCC tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
}
}
-func TestSetAvailableCongestionControl(t *testing.T) {
+func TestStackSetAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcp.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.AvailableCongestionControlOption
+ var cc tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
}
}
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 69a43b6f4..a4d89e24d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -520,35 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window(rcvWnd) and any options if
+// specified.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if err != tcpip.ErrConnectStarted {
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -590,8 +576,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -604,6 +589,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.Port = tcpHdr.SourcePort()
}
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if epRcvBuf != nil {
+ if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+ }
+ c.Connect(iss, rcvWnd, options)
+}
+
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
// sending data and ACK packets easy while being able to manipulate the sequence
// numbers and timestamp values as needed.