From d35f07b36a40ee94b3a5b588a5fa62572753c620 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 8 Sep 2020 12:15:58 -0700 Subject: Improve type safety for transport protocol options The existing implementation for TransportProtocol.{Set}Option take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for transport protocol options that may be set or queried which transport protocol option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. RELNOTES: n/a PiperOrigin-RevId: 330559811 --- pkg/tcpip/transport/tcp/protocol.go | 178 ++++++++++++++---------------------- 1 file changed, 69 insertions(+), 109 deletions(-) (limited to 'pkg/tcpip/transport/tcp/protocol.go') diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c5afa2680..63ec12be8 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -79,50 +79,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// Recovery is used by stack.(*Stack).TransportProtocolOption to -// set loss detection algorithm in TCP. -type Recovery int32 - -const ( - // RACKLossDetection indicates RACK is used for loss detection and - // recovery. - RACKLossDetection Recovery = 1 << iota - - // RACKStaticReoWnd indicates the reordering window should not be - // adjusted when DSACK is received. - RACKStaticReoWnd - - // RACKNoDupTh indicates RACK should not consider the classic three - // duplicate acknowledgements rule to mark the segments as lost. This - // is used when reordering is not detected. - RACKNoDupTh -) - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -183,10 +139,10 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool - recovery Recovery + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -296,49 +252,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.Lock() - p.sackEnabled = bool(v) + p.sackEnabled = bool(*v) p.mu.Unlock() return nil - case Recovery: + case *tcpip.TCPRecovery: p.mu.Lock() - p.recovery = Recovery(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -347,75 +303,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.lingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.timeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitReuseOption: - if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.timeWaitReuse = v + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) - } + case *tcpip.TCPMinRTOOption: p.mu.Lock() - p.minRTO = time.Duration(v) + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) - } + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -425,33 +385,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *Recovery: + case *tcpip.TCPRecovery: p.mu.RLock() - *v = Recovery(p.recovery) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -463,15 +423,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil @@ -567,12 +527,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -587,7 +547,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: RACKLossDetection, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p -- cgit v1.2.3