diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 73 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 148 |
8 files changed, 208 insertions, 90 deletions
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index d0d1efd0d..680eafd16 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -315,3 +315,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } return (addr[0] & 0xf0) == 0xe0 } + +// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback +// address (belongs to 127.0.0.1/8 subnet). +func IsV4LoopbackAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return addr[0] == 0x7f +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 4f367fe4c..ea3823898 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,6 +98,9 @@ const ( // section 5. IPv6MinimumMTU = 1280 + // IPv6Loopback is the IPv6 Loopback address. + IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 091bc5281..07c85ce59 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -958,6 +958,26 @@ type SocketDetachFilterOption int // and port of a redirected packet. type OriginalDestinationOption FullAddress +// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to +// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for +// new connections when it is safe from protocol viewpoint. +type TCPTimeWaitReuseOption uint8 + +const ( + // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot + // be reused for new connections. + TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota + + // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can + // be reused for new connections irrespective of the src/dest addresses. + TCPTimeWaitReuseGlobal + + // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can + // only be reused if the connection was a connection over loopback. i.e src/dest adddresses + // are loopback addresses. + TCPTimeWaitReuseLoopbackOnly +) + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 8dd759ba2..46702906b 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1706,7 +1706,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 || n¬ifyAbort != 0 { + if n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b8b52b03d..d08cfe0ff 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -449,10 +449,11 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. - // - // recentTS must be read/written atomically. recentTS uint32 + // recentTSTime is the unix time when we updated recentTS last. + recentTSTime time.Time `state:".(unixTime)"` + // tsOffset is a randomized offset added to the value of the // TSVal field in the timestamp option. tsOffset uint32 @@ -795,15 +796,15 @@ func (e *endpoint) EndpointState() EndpointState { return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) } -// setRecentTimestamp atomically sets the recentTS field to the -// provided value. +// setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - atomic.StoreUint32(&e.recentTS, recentTS) + e.recentTS = recentTS + e.recentTSTime = time.Now() } -// recentTimestamp atomically reads and returns the value of the recentTS field. +// recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return atomic.LoadUint32(&e.recentTS) + return e.recentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -902,7 +903,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case StateClose, StateError: + case StateClose, StateError, StateTimeWait: // Ready for anything. result = mask @@ -2148,12 +2149,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc h.Write(portBuf) portOffset := h.Sum32() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err)) + } + + reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal + if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { + switch netProto { + case header.IPv4ProtocolNumber: + reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + case header.IPv6ProtocolNumber: + reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + } + } + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { - return false, nil + if err != tcpip.ErrPortInUse || !reuse { + return false, nil + } + transEPID := e.ID + transEPID.LocalPort = p + // Check if an endpoint is registered with demuxer in TIME-WAIT and if + // we can reuse it. If we can't find a transport endpoint then we just + // skip using this port as it's possible that either an endpoint has + // bound the port but not registered with demuxer yet (no listen/connect + // done yet) or the reservation was freed between the check above and + // the FindTransportEndpoint below. But rather than retry the same port + // we just skip it and move on. + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + if transEP == nil { + // ReservePort failed but there is no registered endpoint with + // demuxer. Which indicates there is at least some endpoint that has + // bound the port. + return false, nil + } + + tcpEP := transEP.(*endpoint) + tcpEP.LockUser() + // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but + // less than 1 second has elapsed since its recentTS was updated then + // we cannot reuse the port. + if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + tcpEP.UnlockUser() + return false, nil + } + // Since the endpoint is in TIME-WAIT it should be safe to acquire its + // Lock while holding the lock for this endpoint as endpoints in + // TIME-WAIT do not acquire locks on other endpoints. + tcpEP.workerCleanup = false + tcpEP.cleanupLocked() + tcpEP.notifyProtocolGoroutine(notifyAbort) + tcpEP.UnlockUser() + // Now try and Reserve again if it fails then we skip. + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + return false, nil + } } id := e.ID diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index abf1ac5c9..723e47ddc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) { e.lastError = tcpip.StringToError(s) } +// saveRecentTSTime is invoked by stateify. +func (e *endpoint) saveRecentTSTime() unixTime { + return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} +} + +// loadRecentTSTime is invoked by stateify. +func (e *endpoint) loadRecentTSTime(unix unixTime) { + e.recentTSTime = time.Unix(unix.second, unix.nano) +} + // saveHardError is invoked by stateify. func (e *EndpointInfo) saveHardError() string { if e.HardError == nil { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2e5093b36..49a673b42 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -191,8 +191,9 @@ type protocol struct { congestionControl string availableCongestionControl []string moderateReceiveBuffer bool - tcpLingerTimeout time.Duration - tcpTimeWaitTimeout time.Duration + lingerTimeout time.Duration + timeWaitTimeout time.Duration + timeWaitReuse tcpip.TCPTimeWaitReuseOption minRTO time.Duration maxRTO time.Duration maxRetries uint32 @@ -358,7 +359,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { v = 0 } p.mu.Lock() - p.tcpLingerTimeout = time.Duration(v) + p.lingerTimeout = time.Duration(v) p.mu.Unlock() return nil @@ -367,7 +368,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { v = 0 } p.mu.Lock() - p.tcpTimeWaitTimeout = time.Duration(v) + p.timeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitReuseOption: + if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.timeWaitReuse = v p.mu.Unlock() return nil @@ -468,13 +478,19 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { case *tcpip.TCPLingerTimeoutOption: p.mu.RLock() - *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout) p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitReuseOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) p.mu.RUnlock() return nil @@ -564,8 +580,9 @@ func NewProtocol() stack.TransportProtocol { }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, - tcpLingerTimeout: DefaultTCPLingerTimeout, - tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + lingerTimeout: DefaultTCPLingerTimeout, + timeWaitTimeout: DefaultTCPTimeWaitTimeout, + timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 132131815..bed45e9a1 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -185,6 +185,7 @@ func (x *endpoint) StateFields() []string { "workerCleanup", "sendTSOk", "recentTS", + "recentTSTime", "tsOffset", "shutdownFlags", "sackPermitted", @@ -230,8 +231,10 @@ func (x *endpoint) StateSave(m state.Sink) { m.SaveValue(3, lastError) var state EndpointState = x.saveState() m.SaveValue(10, state) + var recentTSTime unixTime = x.saveRecentTSTime() + m.SaveValue(25, recentTSTime) var acceptedChan []*endpoint = x.saveAcceptedChan() - m.SaveValue(50, acceptedChan) + m.SaveValue(51, acceptedChan) m.Save(0, &x.EndpointInfo) m.Save(1, &x.waiterQueue) m.Save(2, &x.uniqueID) @@ -255,41 +258,41 @@ func (x *endpoint) StateSave(m state.Sink) { m.Save(22, &x.workerCleanup) m.Save(23, &x.sendTSOk) m.Save(24, &x.recentTS) - m.Save(25, &x.tsOffset) - m.Save(26, &x.shutdownFlags) - m.Save(27, &x.sackPermitted) - m.Save(28, &x.sack) - m.Save(29, &x.bindToDevice) - m.Save(30, &x.delay) - m.Save(31, &x.cork) - m.Save(32, &x.scoreboard) - m.Save(33, &x.slowAck) - m.Save(34, &x.segmentQueue) - m.Save(35, &x.synRcvdCount) - m.Save(36, &x.userMSS) - m.Save(37, &x.maxSynRetries) - m.Save(38, &x.windowClamp) - m.Save(39, &x.sndBufSize) - m.Save(40, &x.sndBufUsed) - m.Save(41, &x.sndClosed) - m.Save(42, &x.sndBufInQueue) - m.Save(43, &x.sndQueue) - m.Save(44, &x.cc) - m.Save(45, &x.packetTooBigCount) - m.Save(46, &x.sndMTU) - m.Save(47, &x.keepalive) - m.Save(48, &x.userTimeout) - m.Save(49, &x.deferAccept) - m.Save(51, &x.rcv) - m.Save(52, &x.snd) - m.Save(53, &x.connectingAddress) - m.Save(54, &x.amss) - m.Save(55, &x.sendTOS) - m.Save(56, &x.gso) - m.Save(57, &x.tcpLingerTimeout) - m.Save(58, &x.closed) - m.Save(59, &x.txHash) - m.Save(60, &x.owner) + m.Save(26, &x.tsOffset) + m.Save(27, &x.shutdownFlags) + m.Save(28, &x.sackPermitted) + m.Save(29, &x.sack) + m.Save(30, &x.bindToDevice) + m.Save(31, &x.delay) + m.Save(32, &x.cork) + m.Save(33, &x.scoreboard) + m.Save(34, &x.slowAck) + m.Save(35, &x.segmentQueue) + m.Save(36, &x.synRcvdCount) + m.Save(37, &x.userMSS) + m.Save(38, &x.maxSynRetries) + m.Save(39, &x.windowClamp) + m.Save(40, &x.sndBufSize) + m.Save(41, &x.sndBufUsed) + m.Save(42, &x.sndClosed) + m.Save(43, &x.sndBufInQueue) + m.Save(44, &x.sndQueue) + m.Save(45, &x.cc) + m.Save(46, &x.packetTooBigCount) + m.Save(47, &x.sndMTU) + m.Save(48, &x.keepalive) + m.Save(49, &x.userTimeout) + m.Save(50, &x.deferAccept) + m.Save(52, &x.rcv) + m.Save(53, &x.snd) + m.Save(54, &x.connectingAddress) + m.Save(55, &x.amss) + m.Save(56, &x.sendTOS) + m.Save(57, &x.gso) + m.Save(58, &x.tcpLingerTimeout) + m.Save(59, &x.closed) + m.Save(60, &x.txHash) + m.Save(61, &x.owner) } func (x *endpoint) StateLoad(m state.Source) { @@ -316,44 +319,45 @@ func (x *endpoint) StateLoad(m state.Source) { m.Load(22, &x.workerCleanup) m.Load(23, &x.sendTSOk) m.Load(24, &x.recentTS) - m.Load(25, &x.tsOffset) - m.Load(26, &x.shutdownFlags) - m.Load(27, &x.sackPermitted) - m.Load(28, &x.sack) - m.Load(29, &x.bindToDevice) - m.Load(30, &x.delay) - m.Load(31, &x.cork) - m.Load(32, &x.scoreboard) - m.Load(33, &x.slowAck) - m.LoadWait(34, &x.segmentQueue) - m.Load(35, &x.synRcvdCount) - m.Load(36, &x.userMSS) - m.Load(37, &x.maxSynRetries) - m.Load(38, &x.windowClamp) - m.Load(39, &x.sndBufSize) - m.Load(40, &x.sndBufUsed) - m.Load(41, &x.sndClosed) - m.Load(42, &x.sndBufInQueue) - m.LoadWait(43, &x.sndQueue) - m.Load(44, &x.cc) - m.Load(45, &x.packetTooBigCount) - m.Load(46, &x.sndMTU) - m.Load(47, &x.keepalive) - m.Load(48, &x.userTimeout) - m.Load(49, &x.deferAccept) - m.LoadWait(51, &x.rcv) - m.LoadWait(52, &x.snd) - m.Load(53, &x.connectingAddress) - m.Load(54, &x.amss) - m.Load(55, &x.sendTOS) - m.Load(56, &x.gso) - m.Load(57, &x.tcpLingerTimeout) - m.Load(58, &x.closed) - m.Load(59, &x.txHash) - m.Load(60, &x.owner) + m.Load(26, &x.tsOffset) + m.Load(27, &x.shutdownFlags) + m.Load(28, &x.sackPermitted) + m.Load(29, &x.sack) + m.Load(30, &x.bindToDevice) + m.Load(31, &x.delay) + m.Load(32, &x.cork) + m.Load(33, &x.scoreboard) + m.Load(34, &x.slowAck) + m.LoadWait(35, &x.segmentQueue) + m.Load(36, &x.synRcvdCount) + m.Load(37, &x.userMSS) + m.Load(38, &x.maxSynRetries) + m.Load(39, &x.windowClamp) + m.Load(40, &x.sndBufSize) + m.Load(41, &x.sndBufUsed) + m.Load(42, &x.sndClosed) + m.Load(43, &x.sndBufInQueue) + m.LoadWait(44, &x.sndQueue) + m.Load(45, &x.cc) + m.Load(46, &x.packetTooBigCount) + m.Load(47, &x.sndMTU) + m.Load(48, &x.keepalive) + m.Load(49, &x.userTimeout) + m.Load(50, &x.deferAccept) + m.LoadWait(52, &x.rcv) + m.LoadWait(53, &x.snd) + m.Load(54, &x.connectingAddress) + m.Load(55, &x.amss) + m.Load(56, &x.sendTOS) + m.Load(57, &x.gso) + m.Load(58, &x.tcpLingerTimeout) + m.Load(59, &x.closed) + m.Load(60, &x.txHash) + m.Load(61, &x.owner) m.LoadValue(3, new(string), func(y interface{}) { x.loadLastError(y.(string)) }) m.LoadValue(10, new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) - m.LoadValue(50, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.LoadValue(25, new(unixTime), func(y interface{}) { x.loadRecentTSTime(y.(unixTime)) }) + m.LoadValue(51, new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) m.AfterLoad(x.afterLoad) } |