diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/tcp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 90 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 66 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 59 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 116 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 18 |
10 files changed, 428 insertions, 106 deletions
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 29454c4b9..4c6f808e5 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -66,6 +66,14 @@ const ( TCPOptionSACK = 5 ) +// Option Lengths. +const ( + TCPOptionMSSLength = 4 + TCPOptionTSLength = 10 + TCPOptionWSLength = 3 + TCPOptionSackPermittedLength = 2 +) + // TCPFields contains the fields of a TCP packet. It is used to describe the // fields of a packet that needs to be encoded. type TCPFields struct { @@ -494,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeMSSOption(mss uint32, b []byte) int { - // mssOptionSize is the number of bytes in a valid MSS option. - const mssOptionSize = 4 - - if len(b) < mssOptionSize { + if len(b) < TCPOptionMSSLength { return 0 } - b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) - return mssOptionSize + b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss) + return TCPOptionMSSLength } // EncodeWSOption encodes the WS TCP option with the WS value in the @@ -509,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeWSOption(ws int, b []byte) int { - if len(b) < 3 { + if len(b) < TCPOptionWSLength { return 0 } - b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) + b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws) return int(b[1]) } @@ -521,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int { // just returns without encoding anything. It returns the number of bytes // written to the provided buffer. func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { - if len(b) < 10 { + if len(b) < TCPOptionTSLength { return 0 } - b[0], b[1] = TCPOptionTS, 10 + b[0], b[1] = TCPOptionTS, TCPOptionTSLength binary.BigEndian.PutUint32(b[2:], tsVal) binary.BigEndian.PutUint32(b[6:], tsEcr) return int(b[1]) @@ -535,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { // encoding anything. It returns the number of bytes written to the provided // buffer. func EncodeSACKPermittedOption(b []byte) int { - if len(b) < 2 { + if len(b) < TCPOptionSackPermittedLength { return 0 } - b[0], b[1] = TCPOptionSACKPermitted, 2 + b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength return int(b[1]) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b39ffa9fb..0ab4c3e19 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -235,11 +235,11 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to user space since + // CopiedBytes is the number of bytes copied to userspace since // this measure began. CopiedBytes int - // PrevCopiedBytes is the number of bytes copied to user space in + // PrevCopiedBytes is the number of bytes copied to userspace in // the previous RTT period. PrevCopiedBytes int diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 1ca4088c9..b7b227328 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -110,6 +110,71 @@ var ( ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) +var messageToError map[string]*Error + +var populate sync.Once + +// StringToError converts an error message to the error. +func StringToError(s string) *Error { + populate.Do(func() { + var errors = []*Error{ + ErrUnknownProtocol, + ErrUnknownNICID, + ErrUnknownDevice, + ErrUnknownProtocolOption, + ErrDuplicateNICID, + ErrDuplicateAddress, + ErrNoRoute, + ErrBadLinkEndpoint, + ErrAlreadyBound, + ErrInvalidEndpointState, + ErrAlreadyConnecting, + ErrAlreadyConnected, + ErrNoPortAvailable, + ErrPortInUse, + ErrBadLocalAddress, + ErrClosedForSend, + ErrClosedForReceive, + ErrWouldBlock, + ErrConnectionRefused, + ErrTimeout, + ErrAborted, + ErrConnectStarted, + ErrDestinationRequired, + ErrNotSupported, + ErrQueueSizeNotSupported, + ErrNotConnected, + ErrConnectionReset, + ErrConnectionAborted, + ErrNoSuchFile, + ErrInvalidOptionValue, + ErrNoLinkAddress, + ErrBadAddress, + ErrNetworkUnreachable, + ErrMessageTooLong, + ErrNoBufferSpace, + ErrBroadcastDisabled, + ErrNotPermitted, + ErrAddressFamilyNotSupported, + } + + messageToError = make(map[string]*Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} + // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") @@ -622,6 +687,19 @@ const ( // // A zero value indicates the default. TTLOption + + // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of + // SYN retransmits that TCP should send before aborting the attempt to + // connect. It cannot exceed 255. + // + // NOTE: This option is currently only stubbed out and is no-op. + TCPSynCountOption + + // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size + // of the advertised window to this value. + // + // NOTE: This option is currently only stubed out and is a no-op + TCPWindowClampOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by @@ -685,11 +763,23 @@ type TCPDeferAcceptOption time.Duration // default MinRTO used by the Stack. type TCPMinRTOOption time.Duration +// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding +// default MaxRTO used by the Stack. +type TCPMaxRTOOption time.Duration + +// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum number of retransmits after which we time out the connection. +type TCPMaxRetriesOption uint64 + // TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify // the number of endpoints that can be in SYN-RCVD state before the stack // switches to using SYN cookies. type TCPSynRcvdCountThresholdOption uint64 +// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide +// default for number of times SYN is retransmitted before aborting a connect. +type TCPSynRetriesOption uint8 + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 07d3e64c8..b5ba972f1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -470,6 +470,17 @@ type endpoint struct { // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 + // maxSynRetries is the maximum number of SYN retransmits that TCP should + // send before aborting the attempt to connect. It cannot exceed 255. + // + // NOTE: This is currently a no-op and does not change the SYN + // retransmissions. + maxSynRetries uint8 + + // windowClamp is used to bound the size of the advertised window to + // this value. + windowClamp uint32 + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -795,8 +806,10 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, - uniqueID: s.UniqueID(), - txHash: s.Rand().Uint32(), + uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), + windowClamp: DefaultReceiveBufferSize, + maxSynRetries: DefaultSynRetries, } var ss SendBufferSizeOption @@ -829,6 +842,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.tcpLingerTimeout = time.Duration(tcpLT) } + var synRetries tcpip.TCPSynRetriesOption + if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil { + e.maxSynRetries = uint8(synRetries) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -1079,7 +1097,7 @@ func (e *endpoint) initialReceiveWindow() int { } // ModerateRecvBuf adjusts the receive buffer and the advertised window -// based on the number of bytes copied to user space. +// based on the number of bytes copied to userspace. func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() @@ -1603,6 +1621,36 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.ttl = uint8(v) e.UnlockUser() + case tcpip.TCPSynCountOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.maxSynRetries = uint8(v) + e.UnlockUser() + + case tcpip.TCPWindowClampOption: + if v == 0 { + e.LockUser() + switch e.EndpointState() { + case StateClose, StateInitial: + e.windowClamp = 0 + e.UnlockUser() + return nil + default: + e.UnlockUser() + return tcpip.ErrInvalidOptionValue + } + } + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if v < rs.Min/2 { + v = rs.Min / 2 + } + } + e.LockUser() + e.windowClamp = uint32(v) + e.UnlockUser() } return nil } @@ -1826,6 +1874,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.UnlockUser() return v, nil + case tcpip.TCPSynCountOption: + e.LockUser() + v := int(e.maxSynRetries) + e.UnlockUser() + return v, nil + + case tcpip.TCPWindowClampOption: + e.LockUser() + v := int(e.windowClamp) + e.UnlockUser() + return v, nil + default: return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 8b7562396..fc43c11e2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -314,7 +314,7 @@ func (e *endpoint) loadLastError(s string) { return } - e.lastError = loadError(s) + e.lastError = tcpip.StringToError(s) } // saveHardError is invoked by stateify. @@ -332,71 +332,7 @@ func (e *EndpointInfo) loadHardError(s string) { return } - e.HardError = loadError(s) -} - -var messageToError map[string]*tcpip.Error - -var populate sync.Once - -func loadError(s string) *tcpip.Error { - populate.Do(func() { - var errors = []*tcpip.Error{ - tcpip.ErrUnknownProtocol, - tcpip.ErrUnknownNICID, - tcpip.ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID, - tcpip.ErrDuplicateAddress, - tcpip.ErrNoRoute, - tcpip.ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound, - tcpip.ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected, - tcpip.ErrNoPortAvailable, - tcpip.ErrPortInUse, - tcpip.ErrBadLocalAddress, - tcpip.ErrClosedForSend, - tcpip.ErrClosedForReceive, - tcpip.ErrWouldBlock, - tcpip.ErrConnectionRefused, - tcpip.ErrTimeout, - tcpip.ErrAborted, - tcpip.ErrConnectStarted, - tcpip.ErrDestinationRequired, - tcpip.ErrNotSupported, - tcpip.ErrQueueSizeNotSupported, - tcpip.ErrNotConnected, - tcpip.ErrConnectionReset, - tcpip.ErrConnectionAborted, - tcpip.ErrNoSuchFile, - tcpip.ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress, - tcpip.ErrBadAddress, - tcpip.ErrNetworkUnreachable, - tcpip.ErrMessageTooLong, - tcpip.ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled, - tcpip.ErrNotPermitted, - tcpip.ErrAddressFamilyNotSupported, - } - - messageToError = make(map[string]*tcpip.Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e + e.HardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index cfd9a4e8e..2a2a7ddeb 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -64,6 +64,10 @@ const ( // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger // in TIME_WAIT state before being marked closed. DefaultTCPTimeWaitTimeout = 60 * time.Second + + // DefaultSynRetries is the default value for the number of SYN retransmits + // before a connect is aborted. + DefaultSynRetries = 6 ) // SACKEnabled option can be used to enable SACK support in the TCP @@ -163,7 +167,10 @@ type protocol struct { tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration minRTO time.Duration + maxRTO time.Duration + maxRetries uint32 synRcvdCount synRcvdCounter + synRetries uint8 dispatcher *dispatcher } @@ -340,12 +347,36 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPMaxRTOOption: + if v < 0 { + v = tcpip.TCPMaxRTOOption(MaxRTO) + } + p.mu.Lock() + p.maxRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMaxRetriesOption: + p.mu.Lock() + p.maxRetries = uint32(v) + p.mu.Unlock() + return nil + case tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() p.synRcvdCount.SetThreshold(uint64(v)) p.mu.Unlock() return nil + case tcpip.TCPSynRetriesOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.synRetries = uint8(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -414,12 +445,30 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPMaxRTOOption: + p.mu.RLock() + *v = tcpip.TCPMaxRTOOption(p.maxRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMaxRetriesOption: + p.mu.RLock() + *v = tcpip.TCPMaxRetriesOption(p.maxRetries) + p.mu.RUnlock() + return nil + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.RLock() *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) p.mu.RUnlock() return nil + case *tcpip.TCPSynRetriesOption: + p.mu.RLock() + *v = tcpip.TCPSynRetriesOption(p.synRetries) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -452,6 +501,9 @@ func NewProtocol() stack.TransportProtocol { tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), + synRetries: DefaultSynRetries, minRTO: MinRTO, + maxRTO: MaxRTO, + maxRetries: MaxRetries, } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 9e547a221..06dc9b7d7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -43,7 +43,8 @@ const ( nDupAckThreshold = 3 // MaxRetries is the maximum number of probe retries sender does - // before timing out the connection, Linux default TCP_RETR2. + // before timing out the connection. + // Linux default TCP_RETR2, net.ipv4.tcp_retries2. MaxRetries = 15 ) @@ -165,6 +166,12 @@ type sender struct { // minRTO is the minimum permitted value for sender.rto. minRTO time.Duration + // maxRTO is the maximum permitted value for sender.rto. + maxRTO time.Duration + + // maxRetries is the maximum permitted retransmissions. + maxRetries uint32 + // maxPayloadSize is the maximum size of the payload of a given segment. // It is initialized on demand. maxPayloadSize int @@ -276,12 +283,24 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // etc. s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) - // Get Stack wide minRTO. - var v tcpip.TCPMinRTOOption - if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + // Get Stack wide config. + var minRTO tcpip.TCPMinRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil { panic(fmt.Sprintf("unable to get minRTO from stack: %s", err)) } - s.minRTO = time.Duration(v) + s.minRTO = time.Duration(minRTO) + + var maxRTO tcpip.TCPMaxRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil { + panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err)) + } + s.maxRTO = time.Duration(maxRTO) + + var maxRetries tcpip.TCPMaxRetriesOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil { + panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err)) + } + s.maxRetries = uint32(maxRetries) return s } @@ -485,7 +504,7 @@ func (s *sender) retransmitTimerExpired() bool { } elapsed := time.Since(s.firstRetransmittedSegXmitTime) - remaining := MaxRTO + remaining := s.maxRTO if uto != 0 { // Cap to the user specified timeout if one is specified. remaining = uto - elapsed @@ -494,24 +513,17 @@ func (s *sender) retransmitTimerExpired() bool { // Always honor the user-timeout irrespective of whether the zero // window probes were acknowledged. // net/ipv4/tcp_timer.c::tcp_probe_timer() - if remaining <= 0 || s.unackZeroWindowProbes >= MaxRetries { + if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries { return false } - if s.rto >= MaxRTO { - // RFC 1122 section: 4.2.2.17 - // A TCP MAY keep its offered receive window closed - // indefinitely. As long as the receiving TCP continues to - // send acknowledgments in response to the probe segments, the - // sending TCP MUST allow the connection to stay open. - if !(s.zeroWindowProbing && s.unackZeroWindowProbes == 0) { - return false - } - } - // Set new timeout. The timer will be restarted by the call to sendData // below. s.rto *= 2 + // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 + if s.rto > s.maxRTO { + s.rto = s.maxRTO + } // Cap RTO to remaining time. if s.rto > remaining { @@ -565,9 +577,20 @@ func (s *sender) retransmitTimerExpired() bool { // send. if s.zeroWindowProbing { s.sendZeroWindowProbe() + // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed + // indefinitely. As long as the receiving TCP continues to send + // acknowledgments in response to the probe segments, the sending TCP + // MUST allow the connection to stay open. return true } + seg := s.writeNext + // RFC 1122 4.2.3.5: Close the connection when the number of + // retransmissions for this segment is beyond a limit. + if seg != nil && seg.xmitCount > s.maxRetries { + return false + } + s.sendData() return true diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index d2c90ebd5..6ef32a1b3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2994,6 +2994,101 @@ func TestSendOnResetConnection(t *testing.T) { } } +// TestMaxRetransmitsTimeout tests if the connection is timed out after +// a segment has been retransmitted MaxRetries times. +func TestMaxRetransmitsTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const numRetries = 2 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { + t.Fatalf("could not set protocol option MaxRetries.\n") + } + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Expect first transmit and MaxRetries retransmits. + for i := 0; i < numRetries+1; i++ { + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + ), + ) + } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := 1 * time.Second + select { + case <-notifyCh: + case <-time.After((2 << numRetries) * initRTO): + t.Fatalf("connection still alive after maximum retransmits.\n") + } + + // Send an ACK and expect a RST as the connection would have been closed. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } +} + +// TestMaxRTO tests if the retransmit interval caps to MaxRTO. +func TestMaxRTO(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + rto := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + } + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + const numRetransmits = 2 + for i := 0; i < numRetransmits; i++ { + start := time.Now() + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { + t.Errorf("Retransmit interval not capped to MaxRTO.\n") + } + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5774,7 +5869,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Invoke the moderation API. This is required for auto-tuning // to happen. This method is normally expected to be invoked // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. + // copying to userspace by invoking it explicitly here. c.EP.ModerateRecvBuf(totalCopied) // Now send a keep-alive packet to trigger an ACK so that we can @@ -6605,9 +6700,16 @@ func TestTCPUserTimeout(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - userTimeout := 50 * time.Millisecond + // Ensure that on the next retransmit timer fire, the user timeout has + // expired. + initRTO := 1 * time.Second + userTimeout := initRTO / 2 c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) // Send some data and wait before ACKing it. @@ -6627,9 +6729,13 @@ func TestTCPUserTimeout(t *testing.T) { ), ) - // Wait for a little over the minimum retransmit timeout of 200ms for - // the retransmitTimer to fire and close the connection. - time.Sleep(tcp.MinRTO + 10*time.Millisecond) + // Wait for the retransmit timer to be fired and the user timeout to cause + // close of the connection. + select { + case <-notifyCh: + case <-time.After(2 * initRTO): + t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + } // No packet should be received as the connection should be silently // closed due to timeout. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..647b2067a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -106,6 +106,9 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID @@ -188,6 +191,15 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + err := e.lastError + e.lastError = nil + return err +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -243,6 +255,10 @@ func (e *endpoint) IPTables() (stack.IPTables, error) { // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + e.rcvMu.Lock() if e.rcvList.Empty() { @@ -382,6 +398,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return 0, nil, err + } + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -853,6 +873,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: + return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -1316,6 +1337,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + if typ == stack.ControlPortUnreachable { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state == StateConnected { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + e.lastError = tcpip.ErrConnectionRefused + } + } } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 466bd9381..851e6b635 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). |