summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go46
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go68
2 files changed, 107 insertions, 7 deletions
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index d81a1dd9b..e9748d8a6 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -11,6 +11,7 @@
package tcp
import (
+ "strings"
"sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
@@ -58,11 +59,21 @@ type ReceiveBufferSizeOption struct {
Max int
}
+// CongestionControlOption sets the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption returns the supported congestion control
+// algorithms.
+type AvailableCongestionControlOption string
+
type protocol struct {
- mu sync.Mutex
- sackEnabled bool
- sendBufferSize SendBufferSizeOption
- recvBufferSize ReceiveBufferSizeOption
+ mu sync.Mutex
+ sackEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ allowedCongestionControl []string
}
// Number returns the tcp protocol number.
@@ -151,6 +162,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+ return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -176,7 +197,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
*v = p.recvBufferSize
p.mu.Unlock()
return nil
-
+ case *CongestionControlOption:
+ p.mu.Lock()
+ *v = CongestionControlOption(p.congestionControl)
+ p.mu.Unlock()
+ return nil
+ case *AvailableCongestionControlOption:
+ p.mu.Lock()
+ *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.Unlock()
+ return nil
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -185,8 +215,10 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{
- sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ congestionControl: "reno",
+ availableCongestionControl: []string{"reno"},
}
})
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 8c54310f2..fa2ef52f9 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2770,3 +2770,71 @@ func TestTCPEndpointProbe(t *testing.T) {
t.Fatalf("TCP Probe function was not called")
}
}
+
+func TestSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcp.CongestionControlOption
+ mustPass bool
+ }{
+ {"reno", true},
+ {"cubic", false},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ }
+
+ var cc tcp.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+ if got, want := cc, tcp.CongestionControlOption("reno"); got != want {
+ t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want)
+ }
+ })
+ }
+}
+
+func TestAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Query permitted congestion control algorithms.
+ var aCC tcp.AvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
+ }
+ if got, want := aCC, tcp.AvailableCongestionControlOption("reno"); got != want {
+ t.Fatalf("unexpected value for AvailableCongestionControlOption: got: %v, want: %v", got, want)
+ }
+}
+
+func TestSetAvailableCongestionControl(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ s := c.Stack()
+
+ // Setting AvailableCongestionControlOption should fail.
+ aCC := tcp.AvailableCongestionControlOption("xyz")
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ }
+
+ // Verify that we still get the expected list of congestion control options.
+ var cc tcp.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ }
+ if got, want := cc, tcp.CongestionControlOption("reno"); got != want {
+ t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want)
+ }
+}