summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go73
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go33
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go148
5 files changed, 176 insertions, 90 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 8dd759ba2..46702906b 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1706,7 +1706,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 || n&notifyAbort != 0 {
+ if n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b8b52b03d..d08cfe0ff 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -449,10 +449,11 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
- //
- // recentTS must be read/written atomically.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -795,15 +796,15 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
}
-// setRecentTimestamp atomically sets the recentTS field to the
-// provided value.
+// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- atomic.StoreUint32(&e.recentTS, recentTS)
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
}
-// recentTimestamp atomically reads and returns the value of the recentTS field.
+// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return atomic.LoadUint32(&e.recentTS)
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -902,7 +903,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
@@ -2148,12 +2149,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
- return false, nil
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ return false, nil
+ }
}
id := e.ID
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index abf1ac5c9..723e47ddc 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) {
e.lastError = tcpip.StringToError(s)
}
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
+}
+
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2e5093b36..49a673b42 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -191,8 +191,9 @@ type protocol struct {
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
- tcpLingerTimeout time.Duration
- tcpTimeWaitTimeout time.Duration
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
@@ -358,7 +359,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpLingerTimeout = time.Duration(v)
+ p.lingerTimeout = time.Duration(v)
p.mu.Unlock()
return nil
@@ -367,7 +368,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpTimeWaitTimeout = time.Duration(v)
+ p.timeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitReuseOption:
+ if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.timeWaitReuse = v
p.mu.Unlock()
return nil
@@ -468,13 +478,19 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
case *tcpip.TCPLingerTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
p.mu.RUnlock()
return nil
@@ -564,8 +580,9 @@ func NewProtocol() stack.TransportProtocol {
},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
- tcpLingerTimeout: DefaultTCPLingerTimeout,
- tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 132131815..bed45e9a1 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -185,6 +185,7 @@ func (x *endpoint) StateFields() []string {
"workerCleanup",
"sendTSOk",
"recentTS",
+ "recentTSTime",
"tsOffset",
"shutdownFlags",
"sackPermitted",
@@ -230,8 +231,10 @@ func (x *endpoint) StateSave(m state.Sink) {
m.SaveValue(3, lastError)
var state EndpointState = x.saveState()
m.SaveValue(10, state)
+ var recentTSTime unixTime = x.saveRecentTSTime()
+ m.SaveValue(25, recentTSTime)
var acceptedChan []*endpoint = x.saveAcceptedChan()
- m.SaveValue(50, acceptedChan)
+ m.SaveValue(51, acceptedChan)
m.Save(0, &x.EndpointInfo)
m.Save(1, &x.waiterQueue)
m.Save(2, &x.uniqueID)
@@ -255,41 +258,41 @@ func (x *endpoint) StateSave(m state.Sink) {
m.Save(22, &x.workerCleanup)
m.Save(23, &x.sendTSOk)
m.Save(24, &x.recentTS)
- m.Save(25, &x.tsOffset)
- m.Save(26, &x.shutdownFlags)
- m.Save(27, &x.sackPermitted)
- m.Save(28, &x.sack)
- m.Save(29, &x.bindToDevice)
- m.Save(30, &x.delay)
- m.Save(31, &x.cork)
- m.Save(32, &x.scoreboard)
- m.Save(33, &x.slowAck)
- m.Save(34, &x.segmentQueue)
- m.Save(35, &x.synRcvdCount)
- m.Save(36, &x.userMSS)
- m.Save(37, &x.maxSynRetries)
- m.Save(38, &x.windowClamp)
- m.Save(39, &x.sndBufSize)
- m.Save(40, &x.sndBufUsed)
- m.Save(41, &x.sndClosed)
- m.Save(42, &x.sndBufInQueue)
- m.Save(43, &x.sndQueue)
- m.Save(44, &x.cc)
- m.Save(45, &x.packetTooBigCount)
- m.Save(46, &x.sndMTU)
- m.Save(47, &x.keepalive)
- m.Save(48, &x.userTimeout)
- m.Save(49, &x.deferAccept)
- m.Save(51, &x.rcv)
- m.Save(52, &x.snd)
- m.Save(53, &x.connectingAddress)
- m.Save(54, &x.amss)
- m.Save(55, &x.sendTOS)
- m.Save(56, &x.gso)
- m.Save(57, &x.tcpLingerTimeout)
- m.Save(58, &x.closed)
- m.Save(59, &x.txHash)
- m.Save(60, &x.owner)
+ m.Save(26, &x.tsOffset)
+ m.Save(27, &x.shutdownFlags)
+ m.Save(28, &x.sackPermitted)
+ m.Save(29, &x.sack)
+ m.Save(30, &x.bindToDevice)
+ m.Save(31, &x.delay)
+ m.Save(32, &x.cork)
+ m.Save(33, &x.scoreboard)
+ m.Save(34, &x.slowAck)
+ m.Save(35, &x.segmentQueue)
+ m.Save(36, &x.synRcvdCount)
+ m.Save(37, &x.userMSS)
+ m.Save(38, &x.maxSynRetries)
+ m.Save(39, &x.windowClamp)
+ m.Save(40, &x.sndBufSize)
+ m.Save(41, &x.sndBufUsed)
+ m.Save(42, &x.sndClosed)
+ m.Save(43, &x.sndBufInQueue)
+ m.Save(44, &x.sndQueue)
+ m.Save(45, &x.cc)
+ m.Save(46, &x.packetTooBigCount)
+ m.Save(47, &x.sndMTU)
+ m.Save(48, &x.keepalive)
+ m.Save(49, &x.userTimeout)
+ m.Save(50, &x.deferAccept)
+ m.Save(52, &x.rcv)
+ m.Save(53, &x.snd)
+ m.Save(54, &x.connectingAddress)
+ m.Save(55, &x.amss)
+ m.Save(56, &x.sendTOS)
+ m.Save(57, &x.gso)
+ m.Save(58, &x.tcpLingerTimeout)
+ m.Save(59, &x.closed)
+ m.Save(60, &x.txHash)
+ m.Save(61, &x.owner)
}
func (x *endpoint) StateLoad(m state.Source) {
@@ -316,44 +319,45 @@ func (x *endpoint) StateLoad(m state.Source) {
m.Load(22, &x.workerCleanup)
m.Load(23, &x.sendTSOk)
m.Load(24, &x.recentTS)
- m.Load(25, &x.tsOffset)
- m.Load(26, &x.shutdownFlags)
- m.Load(27, &x.sackPermitted)
- m.Load(28, &x.sack)
- m.Load(29, &x.bindToDevice)
- m.Load(30, &x.delay)
- m.Load(31, &x.cork)
- m.Load(32, &x.scoreboard)
- m.Load(33, &x.slowAck)
- m.LoadWait(34, &x.segmentQueue)
- m.Load(35, &x.synRcvdCount)
- m.Load(36, &x.userMSS)
- m.Load(37, &x.maxSynRetries)
- m.Load(38, &x.windowClamp)
- m.Load(39, &x.sndBufSize)
- m.Load(40, &x.sndBufUsed)
- m.Load(41, &x.sndClosed)
- m.Load(42, &x.sndBufInQueue)
- m.LoadWait(43, &x.sndQueue)
- m.Load(44, &x.cc)
- m.Load(45, &x.packetTooBigCount)
- m.Load(46, &x.sndMTU)
- m.Load(47, &x.keepalive)
- m.Load(48, &x.userTimeout)
- m.Load(49, &x.deferAccept)
- m.LoadWait(51, &x.rcv)
- m.LoadWait(52, &x.snd)
- m.Load(53, &x.connectingAddress)
- m.Load(54, &x.amss)
- m.Load(55, &x.sendTOS)
- m.Load(56, &x.gso)
- m.Load(57, &x.tcpLingerTimeout)
- m.Load(58, &x.closed)
- m.Load(59, &x.txHash)
- m.Load(60, &x.owner)
+ m.Load(26, &x.tsOffset)
+ m.Load(27, &x.shutdownFlags)
+ m.Load(28, &x.sackPermitted)
+ m.Load(29, &x.sack)
+ m.Load(30, &x.bindToDevice)
+ m.Load(31, &x.delay)
+ m.Load(32, &x.cork)
+ m.Load(33, &x.scoreboard)
+ m.Load(34, &x.slowAck)
+ m.LoadWait(35, &x.segmentQueue)
+ m.Load(36, &x.synRcvdCount)
+ m.Load(37, &x.userMSS)
+ m.Load(38, &x.maxSynRetries)
+ m.Load(39, &x.windowClamp)
+ m.Load(40, &x.sndBufSize)
+ m.Load(41, &x.sndBufUsed)
+ m.Load(42, &x.sndClosed)
+ m.Load(43, &x.sndBufInQueue)
+ m.LoadWait(44, &x.sndQueue)
+ m.Load(45, &x.cc)
+ m.Load(46, &x.packetTooBigCount)
+ m.Load(47, &x.sndMTU)
+ m.Load(48, &x.keepalive)
+ m.Load(49, &x.userTimeout)
+ m.Load(50, &x.deferAccept)
+ m.LoadWait(52, &x.rcv)
+ m.LoadWait(53, &x.snd)
+ m.Load(54, &x.connectingAddress)
+ m.Load(55, &x.amss)
+ m.Load(56, &x.sendTOS)
+ m.Load(57, &x.gso)
+ m.Load(58, &x.tcpLingerTimeout)
+ m.Load(59, &x.closed)
+ m.Load(60, &x.txHash)
+ m.Load(61, &x.owner)
m.LoadValue(3, new(string), func(y interface{}) { x.loadLastError(y.(string)) })
m.LoadValue(10, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) })
- m.LoadValue(50, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.LoadValue(25, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) })
+ m.LoadValue(51, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
m.AfterLoad(x.afterLoad)
}