From d35f07b36a40ee94b3a5b588a5fa62572753c620 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 8 Sep 2020 12:15:58 -0700 Subject: Improve type safety for transport protocol options The existing implementation for TransportProtocol.{Set}Option take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for transport protocol options that may be set or queried which transport protocol option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. RELNOTES: n/a PiperOrigin-RevId: 330559811 --- pkg/sentry/socket/netstack/stack.go | 22 +-- pkg/tcpip/stack/registration.go | 4 +- pkg/tcpip/stack/stack.go | 4 +- pkg/tcpip/stack/transport_test.go | 59 ++----- pkg/tcpip/tcpip.go | 127 ++++++++++++++- pkg/tcpip/transport/icmp/protocol.go | 4 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/dual_stack_test.go | 5 +- pkg/tcpip/transport/tcp/endpoint.go | 20 +-- pkg/tcpip/transport/tcp/endpoint_state.go | 4 +- pkg/tcpip/transport/tcp/protocol.go | 178 ++++++++------------- pkg/tcpip/transport/tcp/tcp_sack_test.go | 15 +- pkg/tcpip/transport/tcp/tcp_test.go | 164 +++++++++++-------- pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 10 +- pkg/tcpip/transport/tcp/testing/context/context.go | 17 +- pkg/tcpip/transport/udp/protocol.go | 4 +- 16 files changed, 368 insertions(+), 271 deletions(-) (limited to 'pkg') diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f0fe18684..36144e1eb 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcp.ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -166,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcp.ReceiveBufferSizeOption{ + rs := tcpip.TCPReceiveBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError() } // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcp.SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -187,29 +187,30 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcp.SendBufferSizeOption{ + ss := tcpip.TCPSendBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError() } // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcp.SACKEnabled + var sack tcpip.TCPSACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() + opt := tcpip.TCPSACKEnabled(enabled) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // TCPRecovery implements inet.Stack.TCPRecovery. func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { - var recovery tcp.Recovery + var recovery tcpip.TCPRecovery if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } @@ -218,7 +219,8 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { // SetTCPRecovery implements inet.Stack.SetTCPRecovery. func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() + opt := tcpip.TCPRecovery(recovery) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // Statistics implements inet.Stack.Statistics. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2d88fa1f7..4fa86a3ac 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -159,12 +159,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 133d90815..def8b0b43 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -817,7 +817,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -832,7 +832,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9292bfccb..ef3457e32 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -291,22 +291,20 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack return true } -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) + case *tcpip.TCPModerateReceiveBufferOption: + f.opts.good = bool(*v) return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) + case *tcpip.TCPModerateReceiveBufferOption: + *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: return tcpip.ErrUnknownProtocolOption @@ -533,41 +531,16 @@ func TestTransportOptions(t *testing.T) { TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } + v := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) + } + v = false + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) + } + if !v { + t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 8ba615521..5e34e27ba 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -864,12 +864,93 @@ func (*DefaultTTLOption) isGettableNetworkProtocolOption() {} func (*DefaultTTLOption) isSettableNetworkProtocolOption() {} -// AvailableCongestionControlOption is used to query the supported congestion -// control algorithms. -type AvailableCongestionControlOption string +// GettableTransportProtocolOption is a marker interface for transport protocol +// options that may be queried. +type GettableTransportProtocolOption interface { + isGettableTransportProtocolOption() +} + +// SettableTransportProtocolOption is a marker interface for transport protocol +// options that may be set. +type SettableTransportProtocolOption interface { + isSettableTransportProtocolOption() +} + +// TCPSACKEnabled the SACK option for TCP. +// +// See: https://tools.ietf.org/html/rfc2018. +type TCPSACKEnabled bool + +func (*TCPSACKEnabled) isGettableTransportProtocolOption() {} + +func (*TCPSACKEnabled) isSettableTransportProtocolOption() {} + +// TCPRecovery is the loss deteoction algorithm used by TCP. +type TCPRecovery int32 + +func (*TCPRecovery) isGettableTransportProtocolOption() {} + +func (*TCPRecovery) isSettableTransportProtocolOption() {} + +const ( + // TCPRACKLossDetection indicates RACK is used for loss detection and + // recovery. + TCPRACKLossDetection TCPRecovery = 1 << iota + + // TCPRACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + TCPRACKStaticReoWnd + + // TCPRACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + TCPRACKNoDupTh +) + +// TCPDelayEnabled enables/disables Nagle's algorithm in TCP. +type TCPDelayEnabled bool + +func (*TCPDelayEnabled) isGettableTransportProtocolOption() {} + +func (*TCPDelayEnabled) isSettableTransportProtocolOption() {} + +// TCPSendBufferSizeRangeOption is the send buffer size range for TCP. +type TCPSendBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP. +type TCPReceiveBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPAvailableCongestionControlOption is the supported congestion control +// algorithms for TCP +type TCPAvailableCongestionControlOption string + +func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {} + +func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {} + +// TCPModerateReceiveBufferOption enables/disables receive buffer moderation +// for TCP. +type TCPModerateReceiveBufferOption bool -// ModerateReceiveBufferOption is used by buffer moderation. -type ModerateReceiveBufferOption bool +func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {} + +func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {} // GettableSocketOption is a marker interface for socket options that may be // queried. @@ -935,6 +1016,10 @@ func (*CongestionControlOption) isGettableSocketOption() {} func (*CongestionControlOption) isSettableSocketOption() {} +func (*CongestionControlOption) isGettableTransportProtocolOption() {} + +func (*CongestionControlOption) isSettableTransportProtocolOption() {} + // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -944,6 +1029,10 @@ func (*TCPLingerTimeoutOption) isGettableSocketOption() {} func (*TCPLingerTimeoutOption) isSettableSocketOption() {} +func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {} + // TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TIME_WAIT state // before being marked closed. @@ -953,6 +1042,10 @@ func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {} func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {} +func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {} + // TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a // accept to return a completed connection only when there is data to be // read. This usually means the listening socket will drop the final ACK @@ -971,6 +1064,10 @@ func (*TCPMinRTOOption) isGettableSocketOption() {} func (*TCPMinRTOOption) isSettableSocketOption() {} +func (*TCPMinRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMinRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MaxRTO used by the Stack. type TCPMaxRTOOption time.Duration @@ -979,6 +1076,10 @@ func (*TCPMaxRTOOption) isGettableSocketOption() {} func (*TCPMaxRTOOption) isSettableSocketOption() {} +func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the // maximum number of retransmits after which we time out the connection. type TCPMaxRetriesOption uint64 @@ -987,6 +1088,10 @@ func (*TCPMaxRetriesOption) isGettableSocketOption() {} func (*TCPMaxRetriesOption) isSettableSocketOption() {} +func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} + // TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify // the number of endpoints that can be in SYN-RCVD state before the stack // switches to using SYN cookies. @@ -996,6 +1101,10 @@ func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} +func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} + // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 @@ -1004,6 +1113,10 @@ func (*TCPSynRetriesOption) isGettableSocketOption() {} func (*TCPSynRetriesOption) isSettableSocketOption() {} +func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {} + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -1062,6 +1175,10 @@ func (*TCPTimeWaitReuseOption) isGettableSocketOption() {} func (*TCPTimeWaitReuseOption) isSettableSocketOption() {} +func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {} + const ( // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot // be reused for new connections. diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 74ef6541e..bb11e4e83 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -109,12 +109,12 @@ func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEnd } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 72df5c2a1..09d53d158 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -522,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 80e9dd465..94207c141 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -560,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } const n = uint16(32) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4cf966b65..8cb769d58 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -849,12 +849,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -864,12 +864,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.cc = cs } - var mrb tcpip.ModerateReceiveBufferOption + var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1609,7 +1609,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1659,7 +1659,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1699,7 +1699,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1748,7 +1748,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. - var avail tcpip.AvailableCongestionControlOption + var avail tcpip.TCPAvailableCongestionControlOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { return err } @@ -2707,7 +2707,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2787,7 +2787,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 723e47ddc..41d0050f3 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -182,14 +182,14 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c5afa2680..63ec12be8 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -79,50 +79,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// Recovery is used by stack.(*Stack).TransportProtocolOption to -// set loss detection algorithm in TCP. -type Recovery int32 - -const ( - // RACKLossDetection indicates RACK is used for loss detection and - // recovery. - RACKLossDetection Recovery = 1 << iota - - // RACKStaticReoWnd indicates the reordering window should not be - // adjusted when DSACK is received. - RACKStaticReoWnd - - // RACKNoDupTh indicates RACK should not consider the classic three - // duplicate acknowledgements rule to mark the segments as lost. This - // is used when reordering is not detected. - RACKNoDupTh -) - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -183,10 +139,10 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool - recovery Recovery + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -296,49 +252,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.Lock() - p.sackEnabled = bool(v) + p.sackEnabled = bool(*v) p.mu.Unlock() return nil - case Recovery: + case *tcpip.TCPRecovery: p.mu.Lock() - p.recovery = Recovery(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -347,75 +303,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.lingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.timeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitReuseOption: - if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.timeWaitReuse = v + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) - } + case *tcpip.TCPMinRTOOption: p.mu.Lock() - p.minRTO = time.Duration(v) + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) - } + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -425,33 +385,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *Recovery: + case *tcpip.TCPRecovery: p.mu.RLock() - *v = Recovery(p.recovery) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -463,15 +423,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil @@ -567,12 +527,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -587,7 +547,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: RACKLossDetection, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 99521f0c1..ef7f5719f 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) + opt := tcpip.TCPSACKEnabled(enable) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 3d09d6def..0d13e1efd 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -309,8 +309,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Lower stackwide TIME_WAIT timeout so that the reservations // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) } c.EP.Close() @@ -432,8 +432,9 @@ func TestConnectResetAfterClose(t *testing.T) { // Set TCPLinger to 3 seconds so that sockets are marked closed // after 3 second in FIN_WAIT2 state. tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -506,8 +507,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -933,8 +935,8 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { // Set the SynRcvd threshold to force a syn cookie based accept to happen. opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { @@ -2867,8 +2869,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(0) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -3146,8 +3149,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { defer c.Cleanup() const numRetries = 2 - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { - t.Fatalf("could not set protocol option MaxRetries.\n") + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3206,8 +3210,9 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + opt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3964,8 +3969,9 @@ func TestReadAfterClosedState(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -4204,11 +4210,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4221,11 +4231,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4252,12 +4266,18 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Set values below the min. @@ -4718,8 +4738,8 @@ func TestStackSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption @@ -4751,12 +4771,12 @@ func TestStackAvailableCongestionControl(t *testing.T) { s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption + var aCC tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -4767,18 +4787,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption + var cc tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) } } @@ -4842,8 +4862,8 @@ func TestEndpointSetCongestionControl(t *testing.T) { func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -5505,8 +5525,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(1) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create TCP endpoint. @@ -5906,13 +5927,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -6027,13 +6054,19 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. @@ -6169,7 +6202,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.TCPDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -6177,17 +6210,17 @@ func TestDelayEnabled(t *testing.T) { } { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.TCPDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } @@ -6625,8 +6658,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 @@ -6775,8 +6809,9 @@ func TestTCPCloseWithData(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } wq := &waiter.Queue{} @@ -7462,9 +7497,10 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } for _, tc := range testCases { - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) if got, want := err, tc.err; got != want { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) } if tc.err != nil { continue diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 8edbff964..44593ed98 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -192,8 +193,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 8bb5e5f6d..baf7df197 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -146,19 +146,22 @@ func New(t *testing.T, mtu uint32) *Context { const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) } // Increase minimum RTO in tests to avoid test flakes due to early // retransmit in case the test executors are overloaded and cause timers // to fire earlier than expected. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { - t.Fatalf("failed to set stack-wide minRTO: %s", err) + minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) } // Some of the congestion control tests send up to 640 packets, we so @@ -1096,7 +1099,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.TCPSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f65751dd4..3f87e8057 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -202,12 +202,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -- cgit v1.2.3