diff options
-rw-r--r-- | pkg/tcpip/tcpip.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 35 |
4 files changed, 45 insertions, 14 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 166d37004..976f0b0d1 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -430,7 +430,10 @@ type TimestampOption int // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO: Add and populate stat fields. -type TCPInfoOption struct{} +type TCPInfoOption struct { + RTT time.Duration + RTTVar time.Duration +} // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index cdb85598d..8cea416d2 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -129,7 +129,10 @@ func (c *cubicState) Update(packetsAcked int) { return } } else { - c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, c.s.srtt) + c.s.rtt.Lock() + srtt := c.s.rtt.srtt + c.s.rtt.Unlock() + c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) } } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cbbbbc084..7c73f0d13 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -788,6 +788,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} + e.mu.RLock() + snd := e.snd + e.mu.RUnlock() + if snd != nil { + snd.rtt.Lock() + o.RTT = snd.rtt.srtt + o.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + } + return nil } @@ -1463,13 +1473,15 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, RTTMeasureTime: e.snd.rttMeasureTime, Closed: e.snd.closed, - SRTT: e.snd.srtt, RTO: e.snd.rto, SRTTInited: e.snd.srttInited, MaxPayloadSize: e.snd.maxPayloadSize, SndWndScale: e.snd.sndWndScale, MaxSentAck: e.snd.maxSentAck, } + e.snd.rtt.Lock() + s.Sender.SRTT = e.snd.rtt.srtt + e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { s.Sender.Cubic = stack.TCPCubicState{ diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 568bd7024..096ea9cd4 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,6 +16,7 @@ package tcp import ( "math" + "sync" "time" "gvisor.googlesource.com/gvisor/pkg/sleep" @@ -116,11 +117,10 @@ type sender struct { resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` - // srtt, rttvar & rto are the "smoothed round-trip time", "round-trip - // time variation" and "retransmit timeout", as defined in section 2 of - // RFC 6298. - srtt time.Duration - rttvar time.Duration + // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", + // "round-trip time variation" and "retransmit timeout", as defined in + // section 2 of RFC 6298. + rtt rtt rto time.Duration srttInited bool @@ -139,6 +139,17 @@ type sender struct { cc congestionControl } +// rtt is a synchronization wrapper used to appease stateify. See the comment +// in sender, where it is used. +// +// +stateify savable +type rtt struct { + sync.Mutex `state:"nosave"` + + srtt time.Duration + rttvar time.Duration +} + // fastRecovery holds information related to fast recovery from a packet loss. // // +stateify savable @@ -265,20 +276,22 @@ func (s *sender) sendAck() { // updateRTO updates the retransmit timeout when a new roud-trip time is // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { + s.rtt.Lock() if !s.srttInited { - s.rttvar = rtt / 2 - s.srtt = rtt + s.rtt.rttvar = rtt / 2 + s.rtt.srtt = rtt s.srttInited = true } else { - diff := s.srtt - rtt + diff := s.rtt.srtt - rtt if diff < 0 { diff = -diff } - s.rttvar = (3*s.rttvar + diff) / 4 - s.srtt = (7*s.srtt + rtt) / 8 + s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 + s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 } - s.rto = s.srtt + 4*s.rttvar + s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.rtt.Unlock() if s.rto < minRTO { s.rto = minRTO } |