diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 46 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 68 |
2 files changed, 107 insertions, 7 deletions
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index d81a1dd9b..e9748d8a6 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -11,6 +11,7 @@ package tcp import ( + "strings" "sync" "gvisor.googlesource.com/gvisor/pkg/tcpip" @@ -58,11 +59,21 @@ type ReceiveBufferSizeOption struct { Max int } +// CongestionControlOption sets the current congestion control algorithm. +type CongestionControlOption string + +// AvailableCongestionControlOption returns the supported congestion control +// algorithms. +type AvailableCongestionControlOption string + type protocol struct { - mu sync.Mutex - sackEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + mu sync.Mutex + sackEnabled bool + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption + congestionControl string + availableCongestionControl []string + allowedCongestionControl []string } // Number returns the tcp protocol number. @@ -151,6 +162,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case CongestionControlOption: + for _, c := range p.availableCongestionControl { + if string(v) == c { + p.mu.Lock() + p.congestionControl = string(v) + p.mu.Unlock() + return nil + } + } + return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } @@ -176,7 +197,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { *v = p.recvBufferSize p.mu.Unlock() return nil - + case *CongestionControlOption: + p.mu.Lock() + *v = CongestionControlOption(p.congestionControl) + p.mu.Unlock() + return nil + case *AvailableCongestionControlOption: + p.mu.Lock() + *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + p.mu.Unlock() + return nil default: return tcpip.ErrUnknownProtocolOption } @@ -185,8 +215,10 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { func init() { stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { return &protocol{ - sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + congestionControl: "reno", + availableCongestionControl: []string{"reno"}, } }) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 8c54310f2..fa2ef52f9 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2770,3 +2770,71 @@ func TestTCPEndpointProbe(t *testing.T) { t.Fatalf("TCP Probe function was not called") } } + +func TestSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcp.CongestionControlOption + mustPass bool + }{ + {"reno", true}, + {"cubic", false}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err) + } + + var cc tcp.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + } + if got, want := cc, tcp.CongestionControlOption("reno"); got != want { + t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want) + } + }) + } +} + +func TestAvailableCongestionControl(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Query permitted congestion control algorithms. + var aCC tcp.AvailableCongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) + } + if got, want := aCC, tcp.AvailableCongestionControlOption("reno"); got != want { + t.Fatalf("unexpected value for AvailableCongestionControlOption: got: %v, want: %v", got, want) + } +} + +func TestSetAvailableCongestionControl(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Setting AvailableCongestionControlOption should fail. + aCC := tcp.AvailableCongestionControlOption("xyz") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + } + + // Verify that we still get the expected list of congestion control options. + var cc tcp.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + } + if got, want := cc, tcp.CongestionControlOption("reno"); got != want { + t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want) + } +} |