author     Kevin Krakauer <krakauer@google.com>   2019-06-12 15:21:22 -0700
committer  Kevin Krakauer <krakauer@google.com>   2019-06-12 15:21:22 -0700
commit     0bbbcafd68154e7c7b46692b84a39fb6bb5f1568 (patch)
tree       d8fba01ad76900715665b0418a786de2d77e2a05 /pkg/tcpip
parent     06a83df533244dc2b3b8adfc1bf0608d3753c1d9 (diff)
parent     70578806e8d3e01fae2249b3e602cd5b05d378a0 (diff)
Merge branch 'master' into iptables-1-pkg
Change-Id: I7457a11de4725e1bf3811420c505d225b1cb6943
Diffstat (limited to 'pkg/tcpip')
28 files changed, 1418 insertions, 873 deletions
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 1f889c2a0..b88e2e7bf 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -21,12 +21,29 @@
 // FD based endpoints can be used in the networking stack by calling New() to
 // create a new endpoint, and then passing it as an argument to
 // Stack.CreateNIC().
+//
+// FD based endpoints can use more than one file descriptor to read incoming
+// packets. If more than one FD is specified and the underlying FD is an
+// AF_PACKET socket, then the endpoint will enable FANOUT mode on the socket
+// so that the host kernel will consistently hash the packets to the sockets.
+// This ensures that packets for the same TCP stream are not reordered.
+//
+// Similarly, if more than one FD is specified where the underlying FD is not
+// AF_PACKET, then it's the caller's responsibility to ensure that all inbound
+// packets on the descriptors are consistently 5-tuple hashed to one of the
+// descriptors to prevent TCP reordering.
+//
+// Since netstack today does not compute 5-tuple hashes for outgoing packets,
+// we only use the first FD to write outbound packets. Once 5-tuple hashes for
+// all outbound packets are available, we will make use of all underlying FDs
+// to write outbound packets.
 package fdbased
 
 import (
 	"fmt"
 	"syscall"
 
+	"golang.org/x/sys/unix"
 	"gvisor.googlesource.com/gvisor/pkg/tcpip"
 	"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
 	"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
@@ -65,8 +82,10 @@ const (
 )
 
 type endpoint struct {
-	// fd is the file descriptor used to send and receive packets.
-	fd int
+	// fds is the set of file descriptors each identifying one inbound/outbound
+	// channel. The endpoint will dispatch from all inbound channels as well as
+	// hash outbound packets to specific channels based on the packet hash.
+	fds []int
 
 	// mtu (maximum transmission unit) is the maximum size of a packet.
 	mtu uint32
@@ -85,8 +104,8 @@ type endpoint struct {
 	// its end of the communication pipe.
 	closed func(*tcpip.Error)
 
-	inboundDispatcher linkDispatcher
-	dispatcher        stack.NetworkDispatcher
+	inboundDispatchers []linkDispatcher
+	dispatcher         stack.NetworkDispatcher
 
 	// packetDispatchMode controls the packet dispatcher used by this
 	// endpoint.
@@ -99,17 +118,47 @@ type endpoint struct {
 
 // Options specify the details about the fd-based endpoint to be created.
 type Options struct {
-	FD             int
-	MTU            uint32
-	EthernetHeader bool
-	ClosedFunc     func(*tcpip.Error)
-	Address        tcpip.LinkAddress
-	SaveRestore    bool
-	DisconnectOk   bool
-	GSOMaxSize     uint32
+	// FDs is a set of FDs used to read/write packets.
+	FDs []int
+
+	// MTU is the mtu to use for this endpoint.
+	MTU uint32
+
+	// EthernetHeader if true, indicates that the endpoint should read/write
+	// ethernet frames instead of IP packets.
+	EthernetHeader bool
+
+	// ClosedFunc is a function to be called when an endpoint's peer (if
+	// any) closes its end of the communication pipe.
+	ClosedFunc func(*tcpip.Error)
+
+	// Address is the link address for this endpoint. Only used if
+	// EthernetHeader is true.
+	Address tcpip.LinkAddress
+
+	// SaveRestore if true, indicates that this NIC capability set should
+	// include CapabilitySaveRestore.
+	SaveRestore bool
+
+	// DisconnectOk if true, indicates that this NIC capability set should
+	// include CapabilityDisconnectOk.
+	DisconnectOk bool
+
+	// GSOMaxSize is the maximum GSO packet size. It is zero if GSO is
+	// disabled.
+	GSOMaxSize uint32
+
+	// PacketDispatchMode specifies the type of inbound dispatcher to be
+	// used for this endpoint.
 	PacketDispatchMode PacketDispatchMode
-	TXChecksumOffload  bool
-	RXChecksumOffload  bool
+
+	// TXChecksumOffload if true, indicates that this endpoint's capability
+	// set should include CapabilityTXChecksumOffload.
+	TXChecksumOffload bool
+
+	// RXChecksumOffload if true, indicates that this endpoint's capability
+	// set should include CapabilityRXChecksumOffload.
+	RXChecksumOffload bool
 }
 
 // New creates a new fd-based endpoint.
@@ -117,10 +166,6 @@ type Options struct {
 // Makes fd non-blocking, but does not take ownership of fd, which must remain
 // open for the lifetime of the returned endpoint.
 func New(opts *Options) (tcpip.LinkEndpointID, error) {
-	if err := syscall.SetNonblock(opts.FD, true); err != nil {
-		return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)
-	}
-
 	caps := stack.LinkEndpointCapabilities(0)
 	if opts.RXChecksumOffload {
 		caps |= stack.CapabilityRXChecksumOffload
@@ -144,8 +189,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
 		caps |= stack.CapabilityDisconnectOk
 	}
 
+	if len(opts.FDs) == 0 {
+		return 0, fmt.Errorf("opts.FDs is empty, at least one FD must be specified")
+	}
+
 	e := &endpoint{
-		fd:                 opts.FD,
+		fds:                opts.FDs,
 		mtu:                opts.MTU,
 		caps:               caps,
 		closed:             opts.ClosedFunc,
@@ -154,46 +203,71 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
 		packetDispatchMode: opts.PacketDispatchMode,
 	}
 
-	isSocket, err := isSocketFD(e.fd)
-	if err != nil {
-		return 0, err
-	}
-	if isSocket {
-		if opts.GSOMaxSize != 0 {
-			e.caps |= stack.CapabilityGSO
-			e.gsoMaxSize = opts.GSOMaxSize
+	// Create per channel dispatchers.
+	for i := 0; i < len(e.fds); i++ {
+		fd := e.fds[i]
+		if err := syscall.SetNonblock(fd, true); err != nil {
+			return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
 		}
-	}
-	e.inboundDispatcher, err = createInboundDispatcher(e, isSocket)
-	if err != nil {
-		return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+
+		isSocket, err := isSocketFD(fd)
+		if err != nil {
+			return 0, err
+		}
+		if isSocket {
+			if opts.GSOMaxSize != 0 {
+				e.caps |= stack.CapabilityGSO
+				e.gsoMaxSize = opts.GSOMaxSize
+			}
+		}
+		inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
+		if err != nil {
+			return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+		}
+		e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
 	}
 
 	return stack.RegisterLinkEndpoint(e), nil
 }
 
-func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) {
+func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
 	// By default use the readv() dispatcher as it works with all kinds of
 	// FDs (tap/tun/unix domain sockets and af_packet).
-	inboundDispatcher, err := newReadVDispatcher(e.fd, e)
+	inboundDispatcher, err := newReadVDispatcher(fd, e)
 	if err != nil {
-		return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err)
+		return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err)
 	}
 
 	if isSocket {
+		sa, err := unix.Getsockname(fd)
+		if err != nil {
+			return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err)
+		}
+		switch sa.(type) {
+		case *unix.SockaddrLinklayer:
+			// Enable PACKET_FANOUT mode if the underlying socket is
+			// of type AF_PACKET.
+			const fanoutID = 1
+			const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+			fanoutArg := fanoutID | fanoutType<<16
+			if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
+				return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
+			}
+		}
+
 		switch e.packetDispatchMode {
 		case PacketMMap:
-			inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e)
+			inboundDispatcher, err = newPacketMMapDispatcher(fd, e)
 			if err != nil {
-				return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err)
+				return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err)
 			}
 		case RecvMMsg:
 			// If the provided FD is a socket then we optimize
 			// packet reads by using recvmmsg() instead of read() to
 			// read packets in a batch.
-			inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e)
+			inboundDispatcher, err = newRecvMMsgDispatcher(fd, e)
 			if err != nil {
-				return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err)
+				return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err)
 			}
 		}
 	}
@@ -215,7 +289,9 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
 	// Link endpoints are not savable. When transportation endpoints are
 	// saved, they stop sending outgoing packets and all incoming packets
 	// are rejected.
-	go e.dispatchLoop() // S/R-SAFE: See above.
+	for i := range e.inboundDispatchers {
+		go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+	}
 }
 
 // IsAttached implements stack.LinkEndpoint.IsAttached.
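The fanout argument above packs a group ID into the low 16 bits and the fanout mode into the high 16 bits. A minimal standalone sketch of the same socket setup, assuming only the named constants from golang.org/x/sys/unix (PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG equals the 0x8000 literal used in the patch):

package fanout

import "golang.org/x/sys/unix"

// joinFanoutGroup puts every AF_PACKET socket in fds into the same
// fanout group, so the kernel hashes each flow to exactly one socket.
func joinFanoutGroup(fds []int, groupID int) error {
	// DEFRAG asks the kernel to reassemble IP fragments before hashing,
	// so fragments of one datagram cannot land on different sockets.
	mode := unix.PACKET_FANOUT_HASH | unix.PACKET_FANOUT_FLAG_DEFRAG
	arg := groupID | mode<<16
	for _, fd := range fds {
		if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, arg); err != nil {
			return err
		}
	}
	return nil
}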
@@ -305,26 +381,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 		}
 	}
 
-	return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView())
+	return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
 	}
 
 	if payload.Size() == 0 {
-		return rawfile.NonBlockingWrite(e.fd, hdr.View())
+		return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
 	}
 
-	return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil)
+	return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
 }
 
 // WriteRawPacket writes a raw packet directly to the file descriptor.
 func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
-	return rawfile.NonBlockingWrite(e.fd, packet)
+	return rawfile.NonBlockingWrite(e.fds[0], packet)
 }
 
 // dispatchLoop reads packets from the file descriptor in a loop and dispatches
 // them to the network stack.
-func (e *endpoint) dispatchLoop() *tcpip.Error {
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
 	for {
-		cont, err := e.inboundDispatcher.dispatch()
+		cont, err := inboundDispatcher.dispatch()
 		if err != nil || !cont {
 			if e.closed != nil {
 				e.closed(err)
@@ -363,7 +439,7 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
 	syscall.SetNonblock(fd, true)
 
 	e := &InjectableEndpoint{endpoint: endpoint{
-		fd:   fd,
+		fds:  []int{fd},
 		mtu:  mtu,
 		caps: capabilities,
 	}}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index fd1722074..ba3e09192 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -67,7 +67,7 @@ func newContext(t *testing.T, opt *Options) *context {
 		done <- struct{}{}
 	}
 
-	opt.FD = fds[1]
+	opt.FDs = []int{fds[1]}
 	epID, err := New(opt)
 	if err != nil {
 		t.Fatalf("Failed to create FD endpoint: %v", err)
	}
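For callers, the migration from Options.FD to Options.FDs is mechanical. A usage sketch, under the assumption that fd0 and fd1 are two descriptors already opened against the same device (socket setup elided):

import (
	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/link/fdbased"
)

// newMultiFDEndpoint dispatches inbound packets from both descriptors;
// as documented above, outbound packets are still written to fd0 only.
func newMultiFDEndpoint(fd0, fd1 int, mtu uint32) (tcpip.LinkEndpointID, error) {
	return fdbased.New(&fdbased.Options{
		FDs: []int{fd0, fd1},
		MTU: mtu,
	})
}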
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index fccabd554..98581e50e 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -118,7 +118,7 @@ func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcp
 // logs the packet before forwarding to the actual dispatcher.
 func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
 	if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
-		logPacket("recv", protocol, vv.First())
+		logPacket("recv", protocol, vv.First(), nil)
 	}
 	if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
 		vs := vv.Views()
@@ -198,7 +198,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
 // the request to the lower endpoint.
 func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
 	if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
-		logPacket("send", protocol, hdr.View())
+		logPacket("send", protocol, hdr.View(), gso)
 	}
 	if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
 		hdrBuf := hdr.View()
@@ -240,7 +240,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 	return e.lower.WritePacket(r, gso, hdr, payload, protocol)
 }
 
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
 	// Figure out the network layer info.
 	var transProto uint8
 	src := tcpip.Address("unknown")
@@ -404,5 +404,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
 		return
 	}
 
+	if gso != nil {
+		details += fmt.Sprintf(" gso: %+v", gso)
+	}
+
 	log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
 }
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index da07a39e5..44b1d5b9b 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -215,7 +215,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 		views[0] = hdr.View()
 		views = append(views, payload.Views()...)
 		vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
-		e.HandlePacket(r, vv)
+		loopedR := r.MakeLoopedRoute()
+		e.HandlePacket(&loopedR, vv)
+		loopedR.Release()
 	}
 	if loop&stack.PacketOut == 0 {
 		return nil
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4b8cd496b..bcae98e1f 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -108,7 +108,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 		views[0] = hdr.View()
 		views = append(views, payload.Views()...)
 		vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
-		e.HandlePacket(r, vv)
+		loopedR := r.MakeLoopedRoute()
+		e.HandlePacket(&loopedR, vv)
+		loopedR.Release()
 	}
 	if loop&stack.PacketOut == 0 {
 		return nil
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 1681de56e..1fa899e7e 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -137,7 +137,7 @@ func main() {
 		log.Fatal(err)
 	}
 
-	linkID, err := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu})
+	linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
 	if err != nil {
 		log.Fatal(err)
 	}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 642607f83..d47085581 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -129,7 +129,7 @@ func main() {
 	}
 
 	linkID, err := fdbased.New(&fdbased.Options{
-		FD:             fd,
+		FDs:            []int{fd},
 		MTU:            mtu,
 		EthernetHeader: *tap,
 		Address:        tcpip.LinkAddress(maddr),
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 3d4c282a9..55ed02479 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -187,3 +187,13 @@ func (r *Route) Clone() Route {
 	r.ref.incRef()
 	return *r
 }
+
+// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast.
+func (r *Route) MakeLoopedRoute() Route {
+	l := r.Clone()
+	if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+		l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
+		l.RemoteLinkAddress = l.LocalLinkAddress
+	}
+	return l
+}
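The address swap in MakeLoopedRoute is the subtle part: the looped-back copy of a multicast packet must appear to arrive from the sender itself, addressed to the group. For example, a send from 192.168.1.2 to 224.0.0.22 is delivered back to the stack with remote 192.168.1.2 and local 224.0.0.22. A restatement over plain strings, for illustration only (not the stack.Route API):

// loopedAddrs mirrors the MakeLoopedRoute swap for a multicast send:
// the inbound copy sees the sender's own address as the remote and the
// group address as the local; a unicast route is left unchanged.
func loopedAddrs(local, remote string, multicast bool) (newLocal, newRemote string) {
	if multicast {
		return remote, local
	}
	return local, remote
}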
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8d74f1543..e8a9392b5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s
 	f.proto.controlCount++
 }
 
+func (f *fakeTransportEndpoint) State() uint32 {
+	return 0
+}
+
 type fakeTransportGoodOption bool
 
 type fakeTransportBadOption bool
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f9886c6e4..04c776205 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -377,6 +377,10 @@ type Endpoint interface {
 	// GetSockOpt gets a socket option. opt should be a pointer to one of the
 	// *Option types.
 	GetSockOpt(opt interface{}) *Error
+
+	// State returns a socket's lifecycle state. The returned value is
+	// protocol-specific and is primarily used for diagnostics.
+	State() uint32
 }
 
 // WriteOptions contains options for Endpoint.Write.
@@ -468,6 +472,14 @@ type KeepaliveIntervalOption time.Duration
 // closed.
 type KeepaliveCountOption int
 
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
 // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
 // TTL value for multicast messages. The default is 1.
 type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9aa6f3978..84a2b53b7 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,7 @@ go_library(
         "//pkg/tcpip/header",
         "//pkg/tcpip/stack",
         "//pkg/tcpip/transport/raw",
+        "//pkg/tcpip/transport/tcp",
         "//pkg/waiter",
     ],
 )
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index e2b90ef10..b8005093a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
 
 // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
 func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
 }
+
+// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
+// expose internal socket state.
+func (e *endpoint) State() uint32 {
+	return 0
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 1daf5823f..e4ff50c91 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b
 		ep.waiterQueue.Notify(waiter.EventIn)
 	}
 }
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+	return 0
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index e31b03f7d..a9dbfb930 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -21,6 +21,7 @@ go_library(
         "accept.go",
         "connect.go",
         "cubic.go",
+        "cubic_state.go",
         "endpoint.go",
         "endpoint_state.go",
         "forwarder.go",
@@ -70,6 +71,7 @@ go_test(
     srcs = [
         "dual_stack_test.go",
         "sack_scoreboard_test.go",
+        "tcp_noracedetector_test.go",
         "tcp_sack_test.go",
         "tcp_test.go",
         "tcp_timestamp_test.go",
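With the option types promoted from package tcp to package tcpip, congestion control becomes selectable through the generic sockopt interface. A minimal usage sketch (ep is assumed to be a TCP-backed tcpip.Endpoint):

import "gvisor.googlesource.com/gvisor/pkg/tcpip"

// useCubic selects cubic and reads the setting back. SetSockOpt fails
// with tcpip.ErrNoSuchFile (ENOENT, matching Linux per the change
// below) if the algorithm is not in the stack's available list.
func useCubic(ep tcpip.Endpoint) *tcpip.Error {
	if err := ep.SetSockOpt(tcpip.CongestionControlOption("cubic")); err != nil {
		return err
	}
	var cc tcpip.CongestionControlOption
	return ep.GetSockOpt(&cc)
}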
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d4b860975..d05259c0a 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -19,7 +19,6 @@ import (
 	"encoding/binary"
 	"hash"
 	"io"
-	"log"
 	"sync"
 	"time"
 
@@ -227,7 +226,6 @@
 	}
 
 	n.isRegistered = true
-	n.state = stateConnecting
 
 	// Create sender and receiver.
 	//
@@ -253,14 +251,15 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 
 	// Perform the 3-way handshake.
 	h := newHandshake(ep, l.rcvWnd)
-	h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+	h.resetToSynRcvd(cookie, irs, opts)
 	if err := h.execute(); err != nil {
 		ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 		ep.Close()
 		return nil, err
 	}
-
-	ep.state = stateConnected
+	ep.mu.Lock()
+	ep.state = StateEstablished
+	ep.mu.Unlock()
 
 	// Update the receive window scaling. We can't do it before the
 	// handshake because it's possible that the peer doesn't support window
@@ -277,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
 	e.mu.RLock()
 	state := e.state
 	e.mu.RUnlock()
-	if state == stateListen {
+	if state == StateListen {
 		e.acceptedChan <- n
 		e.waiterQueue.Notify(waiter.EventIn)
 	} else {
@@ -295,7 +294,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
 	defer decSynRcvdCount()
 	defer e.decSynRcvdCount()
 	defer s.decRef()
-
 	n, err := ctx.createEndpointAndPerformHandshake(s, opts)
 	if err != nil {
 		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -307,8 +305,7 @@
 
 func (e *endpoint) incSynRcvdCount() bool {
 	e.mu.Lock()
-	log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount)
-	if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+	if e.synRcvdCount >= cap(e.acceptedChan) {
 		e.mu.Unlock()
 		return false
 	}
@@ -323,6 +320,16 @@ func (e *endpoint) decSynRcvdCount() {
 	e.mu.Unlock()
 }
 
+func (e *endpoint) acceptQueueIsFull() bool {
+	e.mu.Lock()
+	if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
+		e.mu.Unlock()
+		return true
+	}
+	e.mu.Unlock()
+	return false
+}
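The effect of this new accounting: connections sitting in the accept queue and connections still in SYN-RCVD now share one backlog budget, where the old check only dropped SYNs once both counts were individually at capacity. A toy restatement with concrete numbers:

// acceptQueueIsFull mirrors the check above: queued (accepted but not
// yet Accept()ed) and synRcvd (handshakes in flight) share the budget.
func acceptQueueIsFull(queued, synRcvd, backlog int) bool {
	return queued+synRcvd >= backlog
}

// With backlog 10, 6 queued connections and 4 handshakes in flight,
// acceptQueueIsFull(6, 4, 10) == true, so a new SYN is dropped; the
// old code would still have admitted it (6 < 10 and 4 < 10).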
 // handleListenSegment is called when a listening endpoint receives a segment
 // and needs to handle it.
 func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
@@ -330,20 +337,27 @@
 	case header.TCPFlagSyn:
 		opts := parseSynSegmentOptions(s)
 		if incSynRcvdCount() {
-			// Drop the SYN if the listen endpoint's accept queue is
-			// overflowing.
-			if e.incSynRcvdCount() {
-				log.Printf("processing syn packet")
+			// Only handle the syn if the following conditions hold
+			//   - accept queue is not full.
+			//   - number of connections in synRcvd state is less than the
+			//     backlog.
+			if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
 				s.incRef()
 				go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
 				return
 			}
-			log.Printf("dropping syn packet")
+			decSynRcvdCount()
 			e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
 			e.stack.Stats().DroppedPackets.Increment()
 			return
 		} else {
-			// TODO(bhaskerh): Increment syncookie sent stat.
+			// If cookies are in use but the endpoint accept queue
+			// is full then drop the syn.
+			if e.acceptQueueIsFull() {
+				e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+				e.stack.Stats().DroppedPackets.Increment()
+				return
+			}
 			cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
 			// Send SYN with window scaling because we currently
 			// don't encode this information in the cookie.
@@ -361,7 +375,7 @@
 	case header.TCPFlagAck:
-		if len(e.acceptedChan) == cap(e.acceptedChan) {
+		if e.acceptQueueIsFull() {
 			// Silently drop the ack as the application can't accept
 			// the connection at this point. The ack will be
 			// retransmitted by the sender anyway and we can
@@ -411,7 +425,7 @@
 		n.tsOffset = 0
 
 		// Switch state to connected.
-		n.state = stateConnected
+		n.state = StateEstablished
 
 		// Do the delivery in a separate goroutine so
 		// that we don't block the listen loop in case
@@ -434,7 +448,7 @@
 	// handleSynSegment() from attempting to queue new connections
 	// to the endpoint.
 	e.mu.Lock()
-	e.state = stateClosed
+	e.state = StateClose
 
 	// Do cleanup if needed.
 	e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 2aed6f286..dd671f7ce 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -60,12 +60,11 @@ const (
 
 // handshake holds the state used during a TCP 3-way handshake.
 type handshake struct {
-	ep       *endpoint
-	listenEP *endpoint // only non nil when doing passive connects.
-	state    handshakeState
-	active   bool
-	flags    uint8
-	ackNum   seqnum.Value
+	ep     *endpoint
+	state  handshakeState
+	active bool
+	flags  uint8
+	ackNum seqnum.Value
 
 	// iss is the initial send sequence number, as defined in RFC 793.
 	iss seqnum.Value
@@ -142,7 +141,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
 
 // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
 // state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
 	h.active = false
 	h.state = handshakeSynRcvd
 	h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -150,7 +149,9 @@
 	h.ackNum = irs + 1
 	h.mss = opts.MSS
 	h.sndWndScale = opts.WS
-	h.listenEP = listenEP
+	h.ep.mu.Lock()
+	h.ep.state = StateSynRecv
+	h.ep.mu.Unlock()
 }
 
 // checkAck checks if the ACK number, if present, of a segment received during
@@ -219,6 +220,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
 	// but resend our own SYN and wait for it to be acknowledged in the
 	// SYN-RCVD state.
 	h.state = handshakeSynRcvd
+	h.ep.mu.Lock()
+	h.ep.state = StateSynRecv
+	h.ep.mu.Unlock()
 	synOpts := header.TCPSynOptions{
 		WS: h.rcvWndScale,
 		TS: rcvSynOpts.TS,
@@ -281,18 +285,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
 	// We have previously received (and acknowledged) the peer's SYN. If the
 	// peer acknowledges our SYN, the handshake is completed.
 	if s.flagIsSet(header.TCPFlagAck) {
-		// listenContext is also used by a tcp.Forwarder and in that
-		// context we do not have a listening endpoint to check the
-		// backlog. So skip this check if listenEP is nil.
-		if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
-			// If there is no space in the accept queue to accept
-			// this endpoint then silently drop this ACK. The peer
-			// will anyway resend the ack and we can complete the
-			// connection the next time it's retransmitted.
-			h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
-			h.ep.stack.Stats().DroppedPackets.Increment()
-			return nil
-		}
 		// If the timestamp option is negotiated and the segment does
 		// not carry a timestamp option then the segment must be dropped
 		// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -663,7 +655,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
 
 // sendRaw sends a TCP segment to the endpoint's peer.
 func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
 	var sackBlocks []header.SACKBlock
-	if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+	if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
 		sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
 	}
 	options := e.makeOptions(sackBlocks)
@@ -714,8 +706,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
 // protocol goroutine.
 func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
 	e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
-
-	e.state = stateError
+	e.state = StateError
 	e.hardError = err
 }
 
@@ -871,14 +862,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
 		// handshake, and then inform potential waiters about its
 		// completion.
 		h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+		e.mu.Lock()
+		h.ep.state = StateSynSent
+		e.mu.Unlock()
+
 		if err := h.execute(); err != nil {
 			e.lastErrorMu.Lock()
 			e.lastError = err
 			e.lastErrorMu.Unlock()
 
 			e.mu.Lock()
-			e.state = stateError
+			e.state = StateError
 			e.hardError = err
+
 			// Lock released below.
 			epilogue()
 
@@ -900,7 +896,7 @@
 
 	// Tell waiters that the endpoint is connected and writable.
 	e.mu.Lock()
-	e.state = stateConnected
+	e.state = StateEstablished
 	drained := e.drainDone != nil
 	e.mu.Unlock()
 	if drained {
@@ -1000,7 +996,7 @@
 					return err
 				}
 			}
-			if e.state != stateError {
+			if e.state != StateError {
 				close(e.drainDone)
 				<-e.undrain
 			}
@@ -1056,8 +1052,8 @@
 
 	// Mark endpoint as closed.
 	e.mu.Lock()
-	if e.state != stateError {
-		e.state = stateClosed
+	if e.state != StateError {
+		e.state = StateClose
 	}
 	// Lock released below.
 	epilogue()
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index e618cd2b9..7b1f5e763 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@ import (
 // control algorithm state.
 //
 // See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
 type cubicState struct {
 	// wLastMax is the previous wMax value.
 	wLastMax float64
@@ -33,7 +34,7 @@ type cubicState struct {
 
 	// t denotes the time when the current congestion avoidance
 	// was entered.
-	t time.Time
+	t time.Time `state:".(unixTime)"`
 
 	// numCongestionEvents tracks the number of congestion events since last
 	// RTO.
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+	return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+	c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b66610ee2..1efe9d3fb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
 import (
 	"fmt"
 	"math"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -32,18 +33,81 @@ import (
 	"gvisor.googlesource.com/gvisor/pkg/waiter"
 )
 
-type endpointState int
+// EndpointState represents the state of a TCP endpoint.
+type EndpointState uint32
 
+// Endpoint states. Note that these are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation for these states if presented to
+// userspace.
 const (
-	stateInitial endpointState = iota
-	stateBound
-	stateListen
-	stateConnecting
-	stateConnected
-	stateClosed
-	stateError
+	// Endpoint states internal to netstack. These map to the TCP state CLOSED.
+	StateInitial EndpointState = iota
+	StateBound
+	StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+	StateError
+
+	// TCP protocol states.
+	StateEstablished
+	StateSynSent
+	StateSynRecv
+	StateFinWait1
+	StateFinWait2
+	StateTimeWait
+	StateClose
+	StateCloseWait
+	StateLastAck
+	StateListen
+	StateClosing
 )
 
+// connected is the set of states where an endpoint is connected to a peer.
+func (s EndpointState) connected() bool {
+	switch s {
+	case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+		return true
+	default:
+		return false
+	}
+}
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+	switch s {
+	case StateInitial:
+		return "INITIAL"
+	case StateBound:
+		return "BOUND"
+	case StateConnecting:
+		return "CONNECTING"
+	case StateError:
+		return "ERROR"
+	case StateEstablished:
+		return "ESTABLISHED"
+	case StateSynSent:
+		return "SYN-SENT"
+	case StateSynRecv:
+		return "SYN-RCVD"
+	case StateFinWait1:
+		return "FIN-WAIT1"
+	case StateFinWait2:
+		return "FIN-WAIT2"
+	case StateTimeWait:
+		return "TIME-WAIT"
+	case StateClose:
+		return "CLOSED"
+	case StateCloseWait:
+		return "CLOSE-WAIT"
+	case StateLastAck:
+		return "LAST-ACK"
+	case StateListen:
+		return "LISTEN"
+	case StateClosing:
+		return "CLOSING"
+	default:
+		panic("unreachable")
+	}
+}
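State() deliberately returns a bare uint32 at the tcpip.Endpoint interface level (see the tcpip.go change above); TCP-aware diagnostic code converts the value back to an EndpointState to get the human-readable form. A sketch of how a caller might use it:

import (
	"fmt"

	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
)

// logState prints a TCP endpoint's lifecycle state, e.g. "ESTABLISHED".
func logState(ep tcpip.Endpoint) {
	st := tcp.EndpointState(ep.State())
	fmt.Printf("tcp endpoint state: %s (%d)\n", st, uint32(st))
}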
 
 // Reasons for notifying the protocol goroutine.
 const (
 	notifyNonZeroReceiveWindow = 1 << iota
@@ -108,10 +172,14 @@
 	rcvBufUsed int
 
 	// The following fields are protected by the mutex.
-	mu             sync.RWMutex `state:"nosave"`
-	id             stack.TransportEndpointID
-	state          endpointState `state:".(endpointState)"`
-	isPortReserved bool          `state:"manual"`
+	mu sync.RWMutex `state:"nosave"`
+	id stack.TransportEndpointID
+
+	// state endpointState `state:".(endpointState)"`
+	// pState ProtocolState
+	state EndpointState `state:".(EndpointState)"`
+
+	isPortReserved bool `state:"manual"`
 	isRegistered   bool
 	boundNICID     tcpip.NICID `state:"manual"`
 	route          stack.Route `state:"manual"`
@@ -219,7 +287,7 @@
 
 	// cc stores the name of the Congestion Control algorithm to use for
 	// this endpoint.
-	cc CongestionControlOption
+	cc tcpip.CongestionControlOption
 
 	// The following are used when a "packet too big" control packet is
 	// received. They are protected by sndBufMu. They are used to
@@ -304,6 +372,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
 		stack:       stack,
 		netProto:    netProto,
 		waiterQueue: waiterQueue,
+		state:       StateInitial,
 		rcvBufSize:  DefaultBufferSize,
 		sndBufSize:  DefaultBufferSize,
 		sndMTU:      int(math.MaxInt32),
@@ -326,7 +395,7 @@
 		e.rcvBufSize = rs.Default
 	}
 
-	var cs CongestionControlOption
+	var cs tcpip.CongestionControlOption
 	if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
 		e.cc = cs
 	}
@@ -335,7 +404,7 @@
 		e.probe = p
 	}
 
-	e.segmentQueue.setLimit(2 * e.rcvBufSize)
+	e.segmentQueue.setLimit(MaxUnprocessedSegments)
 	e.workMu.Init()
 	e.workMu.Lock()
 	e.tsOffset = timeStampOffset()
@@ -351,14 +420,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
 	defer e.mu.RUnlock()
 
 	switch e.state {
-	case stateInitial, stateBound, stateConnecting:
+	case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
 		// Ready for nothing.
 
-	case stateClosed, stateError:
+	case StateClose, StateError:
 		// Ready for anything.
 		result = mask
 
-	case stateListen:
+	case StateListen:
 		// Check if there's anything in the accepted channel.
 		if (mask & waiter.EventIn) != 0 {
 			if len(e.acceptedChan) > 0 {
@@ -366,7 +435,7 @@
 			}
 		}
 
-	case stateConnected:
+	case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
 		// Determine if the endpoint is writable if requested.
 		if (mask & waiter.EventOut) != 0 {
 			e.sndBufMu.Lock()
@@ -427,7 +496,7 @@ func (e *endpoint) Close() {
 	// are immediately available for reuse after Close() is called. If also
 	// registered, we unregister as well otherwise the next user would fail
 	// in Listen() when trying to register.
-	if e.state == stateListen && e.isPortReserved {
+	if e.state == StateListen && e.isPortReserved {
 		if e.isRegistered {
 			e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
 			e.isRegistered = false
@@ -487,15 +556,15 @@
 	e.mu.RLock()
 	// The endpoint can be read if it's connected, or if it's already closed
 	// but has some pending unread data. Also note that a RST being received
-	// would cause the state to become stateError so we should allow the
+	// would cause the state to become StateError so we should allow the
 	// reads to proceed before returning a ECONNRESET.
 	e.rcvListMu.Lock()
 	bufUsed := e.rcvBufUsed
-	if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+	if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
 		e.rcvListMu.Unlock()
 		he := e.hardError
 		e.mu.RUnlock()
-		if s == stateError {
+		if s == StateError {
 			return buffer.View{}, tcpip.ControlMessages{}, he
 		}
 		return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -511,7 +580,7 @@
 
 func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
 	if e.rcvBufUsed == 0 {
-		if e.rcvClosed || e.state != stateConnected {
+		if e.rcvClosed || !e.state.connected() {
 			return buffer.View{}, tcpip.ErrClosedForReceive
 		}
 		return buffer.View{}, tcpip.ErrWouldBlock
@@ -547,9 +616,9 @@
 	defer e.mu.RUnlock()
 
 	// The endpoint cannot be written to if it's not connected.
-	if e.state != stateConnected {
+	if !e.state.connected() {
 		switch e.state {
-		case stateError:
+		case StateError:
 			return 0, nil, e.hardError
 		default:
 			return 0, nil, tcpip.ErrClosedForSend
@@ -612,8 +681,8 @@
 
 	// The endpoint can be read if it's connected, or if it's already closed
 	// but has some pending unread data.
-	if s := e.state; s != stateConnected && s != stateClosed {
-		if s == stateError {
+	if s := e.state; !s.connected() && s != StateClose {
+		if s == StateError {
 			return 0, tcpip.ControlMessages{}, e.hardError
 		}
 		return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -623,7 +692,7 @@
 	defer e.rcvListMu.Unlock()
 
 	if e.rcvBufUsed == 0 {
-		if e.rcvClosed || e.state != stateConnected {
+		if e.rcvClosed || !e.state.connected() {
 			return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
 		}
 		return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -757,8 +826,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
 		}
 		e.rcvListMu.Unlock()
 
-		e.segmentQueue.setLimit(2 * size)
-
 		e.notifyProtocolGoroutine(mask)
 		return nil
 
@@ -791,7 +858,7 @@
 		defer e.mu.Unlock()
 
 		// We only allow this to be set when we're in the initial state.
-		if e.state != stateInitial {
+		if e.state != StateInitial {
 			return tcpip.ErrInvalidEndpointState
 		}
 
@@ -832,6 +899,40 @@
 		e.mu.Unlock()
 		return nil
 
+	case tcpip.CongestionControlOption:
+		// Query the available cc algorithms in the stack and
+		// validate that the specified algorithm is actually
+		// supported in the stack.
+		var avail tcpip.AvailableCongestionControlOption
+		if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+			return err
+		}
+		availCC := strings.Split(string(avail), " ")
+		for _, cc := range availCC {
+			if v == tcpip.CongestionControlOption(cc) {
+				// Acquire the work mutex as we may need to
+				// reinitialize the congestion control state.
+				e.mu.Lock()
+				state := e.state
+				e.cc = v
+				e.mu.Unlock()
+				switch state {
+				case StateEstablished:
+					e.workMu.Lock()
+					e.mu.Lock()
+					if e.state == state {
+						e.snd.cc = e.snd.initCongestionControl(e.cc)
+					}
+					e.mu.Unlock()
+					e.workMu.Unlock()
+				}
+				return nil
+			}
+		}
+
+		// Linux returns ENOENT when an invalid congestion
+		// control algorithm is specified.
+		return tcpip.ErrNoSuchFile
+
 	default:
 		return nil
 	}
@@ -843,7 +944,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
 	defer e.mu.RUnlock()
 
 	// The endpoint cannot be in listen state.
-	if e.state == stateListen {
+	if e.state == StateListen {
 		return 0, tcpip.ErrInvalidEndpointState
 	}
 
@@ -1001,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
 		}
 		return nil
 
+	case *tcpip.CongestionControlOption:
+		e.mu.Lock()
+		*o = e.cc
+		e.mu.Unlock()
+		return nil
+
 	default:
 		return tcpip.ErrUnknownProtocolOption
 	}
@@ -1059,7 +1166,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
 	nicid := addr.NIC
 	switch e.state {
-	case stateBound:
+	case StateBound:
 		// If we're already bound to a NIC but the caller is requesting
 		// that we use a different one now, we cannot proceed.
 		if e.boundNICID == 0 {
@@ -1072,16 +1179,16 @@
 
 		nicid = e.boundNICID
 
-	case stateInitial:
-		// Nothing to do. We'll eventually fill-in the gaps in the ID
-		// (if any) when we find a route.
+	case StateInitial:
+		// Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
+		// when we find a route.
 
-	case stateConnecting:
-		// A connection request has already been issued but hasn't
-		// completed yet.
+	case StateConnecting, StateSynSent, StateSynRecv:
+		// A connection request has already been issued but hasn't completed
+		// yet.
 		return tcpip.ErrAlreadyConnecting
 
-	case stateConnected:
+	case StateEstablished:
 		// The endpoint is already connected. If caller hasn't been notified yet, return success.
 		if !e.isConnectNotified {
 			e.isConnectNotified = true
@@ -1090,7 +1197,7 @@
 		// Otherwise return that it's already connected.
 		return tcpip.ErrAlreadyConnected
 
-	case stateError:
+	case StateError:
 		return e.hardError
 
 	default:
@@ -1156,7 +1263,7 @@
 	}
 
 	e.isRegistered = true
-	e.state = stateConnecting
+	e.state = StateConnecting
 	e.route = r.Clone()
 	e.boundNICID = nicid
 	e.effectiveNetProtos = netProtos
@@ -1177,7 +1284,7 @@
 		}
 		e.segmentQueue.mu.Unlock()
 		e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
-		e.state = stateConnected
+		e.state = StateEstablished
 	}
 
 	if run {
@@ -1201,8 +1308,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
 	defer e.mu.Unlock()
 	e.shutdownFlags |= flags
 
-	switch e.state {
-	case stateConnected:
+	switch {
+	case e.state.connected():
 		// Close for read.
 		if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
 			// Mark read side as closed.
@@ -1243,7 +1350,7 @@
 			e.sndCloseWaker.Assert()
 		}
 
-	case stateListen:
+	case e.state == StateListen:
 		// Tell protocolListenLoop to stop.
 		if flags&tcpip.ShutdownRead != 0 {
 			e.notifyProtocolGoroutine(notifyClose)
@@ -1271,7 +1378,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
 	// When the endpoint shuts down, it sets workerCleanup to true, and from
 	// that point onward, acceptedChan is the responsibility of the cleanup()
 	// method (and should not be touched anywhere else, including here).
-	if e.state == stateListen && !e.workerCleanup {
+	if e.state == StateListen && !e.workerCleanup {
 		// Adjust the size of the channel iff we can fit existing
 		// pending connections into the new one.
 		if len(e.acceptedChan) > backlog {
@@ -1290,7 +1397,7 @@
 	}
 
 	// Endpoint must be bound before it can transition to listen mode.
-	if e.state != stateBound {
+	if e.state != StateBound {
 		return tcpip.ErrInvalidEndpointState
 	}
 
@@ -1300,7 +1407,7 @@
 	}
 
 	e.isRegistered = true
-	e.state = stateListen
+	e.state = StateListen
 	if e.acceptedChan == nil {
 		e.acceptedChan = make(chan *endpoint, backlog)
 	}
@@ -1327,7 +1434,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
 	defer e.mu.RUnlock()
 
 	// Endpoint must be in listen state before it can accept connections.
-	if e.state != stateListen {
+	if e.state != StateListen {
 		return nil, nil, tcpip.ErrInvalidEndpointState
 	}
 
@@ -1355,7 +1462,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
 	// Don't allow binding once endpoint is not in the initial state
 	// anymore. This is because once the endpoint goes into a connected or
 	// listen state, it is already bound.
-	if e.state != stateInitial {
+	if e.state != StateInitial {
 		return tcpip.ErrAlreadyBound
 	}
 
@@ -1410,7 +1517,7 @@
 	}
 
 	// Mark endpoint as bound.
-	e.state = stateBound
+	e.state = StateBound
 
 	return nil
 }
@@ -1432,7 +1539,7 @@
 	e.mu.RLock()
 	defer e.mu.RUnlock()
 
-	if e.state != stateConnected {
+	if !e.state.connected() {
 		return tcpip.FullAddress{}, tcpip.ErrNotConnected
 	}
 
@@ -1741,3 +1848,11 @@ func (e *endpoint) initGSO() {
 	gso.MaxSize = e.route.GSOMaxSize()
 	e.gso = gso
 }
+
+// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
+// state for diagnostics.
+func (e *endpoint) State() uint32 {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	return uint32(e.state)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 27b0be046..5f30c2374 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() {
 	defer e.mu.Unlock()
 
 	switch e.state {
-	case stateInitial, stateBound:
-	case stateConnected:
+	case StateInitial, StateBound:
+	case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
 		if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
 			if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
 				panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
@@ -66,17 +66,17 @@
 			break
 		}
 		fallthrough
-	case stateListen, stateConnecting:
+	case StateListen, StateConnecting:
 		e.drainSegmentLocked()
-		if e.state != stateClosed && e.state != stateError {
+		if e.state != StateClose && e.state != StateError {
 			if !e.workerRunning {
 				panic("endpoint has no worker running in listen, connecting, or connected state")
 			}
 			break
 		}
 		fallthrough
-	case stateError, stateClosed:
-		for e.state == stateError && e.workerRunning {
+	case StateError, StateClose:
+		for e.state == StateError && e.workerRunning {
 			e.mu.Unlock()
 			time.Sleep(100 * time.Millisecond)
 			e.mu.Lock()
@@ -92,7 +92,7 @@
 		panic("endpoint still has waiters upon save")
 	}
 
-	if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+	if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
 		panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
 	}
 }
@@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
 	}
 }
 
 // saveState is invoked by stateify.
-func (e *endpoint) saveState() endpointState {
+func (e *endpoint) saveState() EndpointState {
 	return e.state
 }
@@ -146,15 +146,15 @@
 var connectingLoading sync.WaitGroup
 
 // Bound endpoint loading happens last.
 
 // loadState is invoked by stateify.
-func (e *endpoint) loadState(state endpointState) {
+func (e *endpoint) loadState(state EndpointState) {
 	// This is to ensure that the loading wait groups include all applicable
 	// endpoints before any asynchronous calls to the Wait() methods.
 	switch state {
-	case stateConnected:
+	case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
 		connectedLoading.Add(1)
-	case stateListen:
+	case StateListen:
 		listenLoading.Add(1)
-	case stateConnecting:
+	case StateConnecting, StateSynSent, StateSynRecv:
 		connectingLoading.Add(1)
 	}
 	e.state = state
@@ -163,12 +163,12 @@
 
 // afterLoad is invoked by stateify.
 func (e *endpoint) afterLoad() {
 	e.stack = stack.StackFromEnv
-	e.segmentQueue.setLimit(2 * e.rcvBufSize)
+	e.segmentQueue.setLimit(MaxUnprocessedSegments)
 	e.workMu.Init()
 
 	state := e.state
 	switch state {
-	case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
+	case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
 		var ss SendBufferSizeOption
 		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
 			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
@@ -181,7 +181,7 @@
 	}
 
 	bind := func() {
-		e.state = stateInitial
+		e.state = StateInitial
 		if len(e.bindAddress) == 0 {
 			e.bindAddress = e.id.LocalAddress
 		}
@@ -191,7 +191,7 @@
 	}
 
 	switch state {
-	case stateConnected:
+	case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
 		bind()
 		if len(e.connectingAddress) == 0 {
 			// This endpoint is accepted by netstack but not yet by
@@ -211,7 +211,7 @@
 			panic("endpoint connecting failed: " + err.String())
 		}
 		connectedLoading.Done()
-	case stateListen:
+	case StateListen:
 		tcpip.AsyncLoading.Add(1)
 		go func() {
 			connectedLoading.Wait()
@@ -223,7 +223,7 @@
 			listenLoading.Done()
 			tcpip.AsyncLoading.Done()
 		}()
-	case stateConnecting:
+	case StateConnecting, StateSynSent, StateSynRecv:
 		tcpip.AsyncLoading.Add(1)
 		go func() {
 			connectedLoading.Wait()
@@ -235,7 +235,7 @@
 			connectingLoading.Done()
 			tcpip.AsyncLoading.Done()
 		}()
-	case stateBound:
+	case StateBound:
 		tcpip.AsyncLoading.Add(1)
 		go func() {
 			connectedLoading.Wait()
@@ -244,7 +244,7 @@
 			bind()
 			tcpip.AsyncLoading.Done()
 		}()
-	case stateClosed:
+	case StateClose:
 		if e.isPortReserved {
 			tcpip.AsyncLoading.Add(1)
 			go func() {
@@ -252,12 +252,12 @@
 				listenLoading.Wait()
 				connectingLoading.Wait()
 				bind()
-				e.state = stateClosed
+				e.state = StateClose
 				tcpip.AsyncLoading.Done()
 			}()
 		}
 		fallthrough
-	case stateError:
+	case StateError:
 		tcpip.DeleteDanglingEndpoint(e)
 	}
 }
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index d31a1edcb..59f4009a1 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -48,6 +48,10 @@ const (
 
 	// MaxBufferSize is the largest size a receive and send buffer can grow to.
 	maxBufferSize = 4 << 20 // 4MB
+
+	// MaxUnprocessedSegments is the maximum number of unprocessed segments
+	// that can be queued for a given endpoint.
+	MaxUnprocessedSegments = 300
 )
 
 // SACKEnabled option can be used to enable SACK support in the TCP
@@ -75,13 +79,6 @@ const (
 	ccCubic = "cubic"
 )
 
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
 type protocol struct {
 	mu                         sync.Mutex
 	sackEnabled                bool
@@ -89,7 +86,6 @@ type protocol struct {
 	recvBufferSize             ReceiveBufferSizeOption
 	congestionControl          string
 	availableCongestionControl []string
-	allowedCongestionControl   []string
 }
 
 // Number returns the tcp protocol number.
@@ -184,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
 		p.mu.Unlock()
 		return nil
 
-	case CongestionControlOption:
+	case tcpip.CongestionControlOption:
 		for _, c := range p.availableCongestionControl {
 			if string(v) == c {
 				p.mu.Lock()
@@ -193,7 +189,9 @@
 				return nil
 			}
 		}
-		return tcpip.ErrInvalidOptionValue
+		// Linux returns ENOENT when an invalid congestion control
+		// is specified.
+		return tcpip.ErrNoSuchFile
 
 	default:
 		return tcpip.ErrUnknownProtocolOption
 	}
@@ -219,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
 		*v = p.recvBufferSize
 		p.mu.Unlock()
 		return nil
-	case *CongestionControlOption:
+	case *tcpip.CongestionControlOption:
 		p.mu.Lock()
-		*v = CongestionControlOption(p.congestionControl)
+		*v = tcpip.CongestionControlOption(p.congestionControl)
 		p.mu.Unlock()
 		return nil
-	case *AvailableCongestionControlOption:
+	case *tcpip.AvailableCongestionControlOption:
 		p.mu.Lock()
-		*v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+		*v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
 		p.mu.Unlock()
 		return nil
 	default:
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index b08a0e356..f02fa6105 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -134,6 +134,7 @@
 	// sequence numbers that have been consumed.
 	TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
 
+	// Handle FIN or FIN-ACK.
 	if s.flagIsSet(header.TCPFlagFin) {
 		r.rcvNxt++
 
@@ -144,6 +145,25 @@
 		r.closed = true
 		r.ep.readyToRead(nil)
 
+		// We just received a FIN, our next state depends on whether we sent a
+		// FIN already or not.
+		r.ep.mu.Lock()
+		switch r.ep.state {
+		case StateEstablished:
+			r.ep.state = StateCloseWait
+		case StateFinWait1:
+			if s.flagIsSet(header.TCPFlagAck) {
+				// FIN-ACK, transition to TIME-WAIT.
+				r.ep.state = StateTimeWait
+			} else {
+				// Simultaneous close, expecting a final ACK.
+				r.ep.state = StateClosing
+			}
+		case StateFinWait2:
+			r.ep.state = StateTimeWait
+		}
+		r.ep.mu.Unlock()
+
 		// Flush out any pending segments, except the very first one if
 		// it happens to be the one we're handling now because the
 		// caller is using it.
@@ -156,6 +176,23 @@
 			r.pendingRcvdSegments[i].decRef()
 		}
 		r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+
+		return true
+	}
+
+	// Handle ACK (not FIN-ACK, which we handled above) during one of the
+	// shutdown states.
+	if s.flagIsSet(header.TCPFlagAck) {
+		r.ep.mu.Lock()
+		switch r.ep.state {
+		case StateFinWait1:
+			r.ep.state = StateFinWait2
+		case StateClosing:
+			r.ep.state = StateTimeWait
+		case StateLastAck:
+			r.ep.state = StateClose
+		}
+		r.ep.mu.Unlock()
+	}
 
 	return true
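Taken together, the two switches added to rcv.go implement the receive side of the TCP close state machine, including simultaneous close. A condensed restatement as a pure function (it mirrors, not replaces, the locked updates in consumeSegment):

import "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"

// nextState returns the post-segment state for the shutdown transitions
// handled above; fin and ack say which flags the segment carried.
func nextState(s tcp.EndpointState, fin, ack bool) tcp.EndpointState {
	if fin {
		switch s {
		case tcp.StateEstablished:
			return tcp.StateCloseWait // peer closed first
		case tcp.StateFinWait1:
			if ack {
				return tcp.StateTimeWait // FIN-ACK acknowledges our FIN too
			}
			return tcp.StateClosing // simultaneous close
		case tcp.StateFinWait2:
			return tcp.StateTimeWait
		}
		return s
	}
	if ack {
		switch s {
		case tcp.StateFinWait1:
			return tcp.StateFinWait2 // our FIN was acked
		case tcp.StateClosing:
			return tcp.StateTimeWait
		case tcp.StateLastAck:
			return tcp.StateClose
		}
	}
	return s
}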
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 3b020e580..e0759225e 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -16,8 +16,6 @@ package tcp
 
 import (
 	"sync"
-
-	"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
 )
 
 // segmentQueue is a bounded, thread-safe queue of TCP segments.
@@ -58,7 +56,7 @@ func (q *segmentQueue) enqueue(s *segment) bool {
 	r := q.used < q.limit
 	if r {
 		q.list.PushBack(s)
-		q.used += s.data.Size() + header.TCPMinimumSize
+		q.used++
 	}
 	q.mu.Unlock()
 
@@ -73,7 +71,7 @@
 	s := q.list.Front()
 	if s != nil {
 		q.list.Remove(s)
-		q.used -= s.data.Size() + header.TCPMinimumSize
+		q.used--
 	}
 	q.mu.Unlock()
 
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index afc1d0a55..daa3e8341 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
 	s := &sender{
 		ep:          ep,
-		sndCwnd:     InitialCwnd,
-		sndSsthresh: math.MaxInt64,
 		sndWnd:      sndWnd,
 		sndUna:      iss + 1,
 		sndNxt:      iss + 1,
@@ -238,7 +236,13 @@
 	return s
 }
 
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsthresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+	s.sndCwnd = InitialCwnd
+	s.sndSsthresh = math.MaxInt64
+
 	switch congestionControlName {
 	case ccCubic:
 		return newCubicCC(s)
@@ -632,6 +636,10 @@
 		}
 		seg.flags = header.TCPFlagAck | header.TCPFlagFin
 		segEnd = seg.sequenceNumber.Add(1)
+		// Transition to FIN-WAIT1 state since we're initiating an active close.
+		s.ep.mu.Lock()
+		s.ep.state = StateFinWait1
+		s.ep.mu.Unlock()
 	} else {
 		// We're sending a non-FIN segment.
 		if seg.flags&header.TCPFlagFin != 0 {
@@ -779,7 +787,7 @@
 			break
 		}
 		dataSent = true
-		s.outstanding++
+		s.outstanding += s.pCount(seg)
 		s.writeNext = seg.Next()
 	}
 }
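The queue bound changes from bytes (2x the receive buffer, which SetSockOpt previously had to keep in sync) to a flat count capped by MaxUnprocessedSegments. A minimal count-bounded queue in the same spirit (plain slice instead of netstack's intrusive list):

import "sync"

// boundedQueue rejects enqueues beyond a fixed element count, the
// policy segmentQueue now applies with MaxUnprocessedSegments (300).
type boundedQueue struct {
	mu    sync.Mutex
	items []interface{}
	limit int
}

// enqueue reports whether s was accepted; on false the caller drops it.
func (q *boundedQueue) enqueue(s interface{}) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	if len(q.items) >= q.limit {
		return false
	}
	q.items = append(q.items, s)
	return true
}

// dequeue returns nil when the queue is empty.
func (q *boundedQueue) dequeue() interface{} {
	q.mu.Lock()
	defer q.mu.Unlock()
	if len(q.items) == 0 {
		return nil
	}
	s := q.items[0]
	q.items = q.items[1:]
	return s
}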
+//
+// +build !race

+package tcp_test
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+func TestFastRecovery(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
+ rtxOffset := bytesRead - maxPayload*expected
+ for i := 0; i < 3; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ }
+
+ // Now send 7 more duplicate acks. Each of these should cause a window
+ // inflation by 1 and cause the sender to send an extra packet.
+ for i := 0; i < 7; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge half of the pending data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+ c.SendAck(790, rtxOffset)
+
+ // Receive the retransmit due to partial ack.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ // Receive the 10 extra packets that should have been released due to
+ // the congestion window inflation in recovery.
+ for i := 0; i < 10; i++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // A partial ACK during recovery should reduce the congestion window by
+ // the number acked. Since we had "expected" packets outstanding before
+ // sending the partial ack and we acked expected/2, the cwnd and
+ // outstanding should be expected/2 + 10 (7 dupacks + 3 for the original
+ // 3 dupacks that triggered fast recovery), which means the sender
+ // should not send any more packets till we ack this one.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
+
+ // At this point, the cwnd should reset to expected/2 and there are 10
+ // packets outstanding.
+ //
+ // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
+ // the same segment that takes us out of recovery. But because of that
+ // the actual cwnd at exit of recovery will be expected/2 + 1 as we
+ // acked a cwnd worth of packets which will increase the cwnd further by
+ // 1 in congestion avoidance.
+ //
+ // Now, in the first iteration, since there are 10 packets outstanding,
+ // we would expect to get expected/2 + 1 - 10 packets, but subsequent
+ // iterations will send us expected/2 + 1 + 1 (per iteration).
+ expected = expected/2 + 1 - 10
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received (after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In congestion avoidance, the packet trains increase by 1 in
+ // each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 10
+ }
+ expected++
+ }
+}
+
+func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // Double the number of expected packets for the next iteration.
+ expected *= 2
+ }
+}
+
+func TestCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd/2.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
+ // acknowledgements), then the congestion avoidance part will consume
+ // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
+ // remains in the "ack count" (which will cause cwnd to be incremented
+ // once it reaches cwnd acks).
+ //
+ // So we're straight into congestion avoidance with cwnd set to
+ // expected/2 + 1.
+ //
+ // Check that packet trains of cwnd packets are sent, and that cwnd is
+ // incremented by 1 after we acknowledge each packet.
+ expected = expected/2 + 1
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In congestion avoidance, the packet trains increase by 1 in
+ // each iteration.
+ expected++
+ }
+}
+
+// cubicCwnd returns an estimate of a cubic window given the
+// originalCwnd, wMax, last congestion event time and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+ cwnd := float64(origCwnd)
+ // We wait 50ms between each iteration so sRTT as computed by cubic
+ // should be close to 50ms.
+ elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+ k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+ wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+ cwnd += (wtRTT - cwnd) / cwnd
+ return int(cwnd)
+}
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ enableCUBIC(t, c)
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd * 0.7.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all pending data.
+ c.SendAck(790, bytesRead)
+
+ // Store away the time we sent the ACK. Assuming a 200ms RTO,
+ // we estimate that the sender will have an RTO 200ms from now
+ // and go back into slow start.
+ packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2, essentially putting the connection
+ // straight into congestion avoidance.
+ wMax := expected
+ // Lower expected as per cubic spec after a congestion event.
+ expected = int(float64(expected) * 0.7)
+ cwnd := expected
+ for i := 0; i < iterations; i++ {
+ // Cubic grows the window independently of ACKs: window growth
+ // is a function of the time elapsed since the last congestion
+ // event. As a result the congestion window does not grow
+ // deterministically in response to ACKs.
+ //
+ // We need to roughly estimate what the cwnd of the sender is
+ // based on when we sent the dupacks.
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+ packetsExpected := cwnd
+ for j := 0; j < packetsExpected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+ t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+ // If our estimate was correct, there should be no more pending packets.
+ // We attempt to read a packet a few times with a short sleep in between
+ // to ensure that we don't see the sender send any unexpected packets.
+ unexpectedPackets := 0
+ for {
+ gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+ if !gotPacket {
+ break
+ }
+ bytesRead += maxPayload
+ unexpectedPackets++
+ time.Sleep(1 * time.Millisecond)
+ }
+ if unexpectedPackets != 0 {
+ t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance)", 5*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+ }
+}
+
+func TestRetransmit(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in two shots. Packets will only be written at the
+ // MTU size though.
+ half := data[:len(data)/2]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ half = data[len(data)/2:]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Wait for a timeout and retransmit.
+ rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + } + + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + } + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the remaining data, making sure that acknowledged data is not + // retransmitted. + for offset := rtxOffset; offset < len(data); offset += maxPayload { + c.ReceiveAndCheckPacket(data, offset, maxPayload) + c.SendAck(790, offset+maxPayload) + } + + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fe037602b..7d8987219 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) { // Receive the SYN-ACK reply. b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) { time.Sleep(3 * time.Second) for { b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + tcpHdr := header.TCP(header.IPv4(b).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { // This is a retransmit of the FIN, ignore it. continue } @@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } - // This final should be ignored because an ACK on a reset doesn't - // mean anything. + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.SeqNum(uint32(c.IRS)+1), )) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Cause a RST to be generated by closing the read end now since we have // unread data. c.EP.Shutdown(tcpip.ShutdownRead) @@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } // The ACK to the FIN should now be rejected since the connection has been // closed by a RST. 
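The endpoint-state assertions above recur throughout these tests. A small helper along these lines could express the pattern; this is a hypothetical sketch, not part of the change, and it assumes only the ep.State() accessor and the tcp.EndpointState conversion already used in the diff:

// checkEndpointState is a hypothetical test helper illustrating the
// assertion pattern used above: convert the endpoint's numeric state to
// a tcp.EndpointState and compare it against the expected TCP state.
func checkEndpointState(t *testing.T, ep tcpip.Endpoint, want tcp.EndpointState) {
	t.Helper()
	if got := tcp.EndpointState(ep.State()); got != want {
		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
	}
}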
@@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { for bytesReceived != dataLen { b := c.GetPacket() numPackets++ - tcp := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcp.Payload()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), @@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); !bytes.Equal(pdata, p) { + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen @@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { if c.TimeStampEnabled { // If timestamp option is enabled, echo back the timestamp and increment // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcp.ParsedOptions() + parsedOpts := tcpHdr.ParsedOptions() tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) options = tsOpt[:] @@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) // Wait for retransmit. time.Sleep(1 * time.Second) @@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcp.SourcePort()), - checker.SeqNum(tcp.SequenceNumber()), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Send SYN-ACK. iss := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -2336,491 +2348,6 @@ func TestFinWithPartialAck(t *testing.T) { }) } -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. 
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. 
- c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. 
- // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. 
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. 
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. 
- for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} - func TestUpdateListenBacklog(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3008,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcp.AckNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) if ack == last { break } @@ -3053,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) { ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Send some data and acknowledge the FIN. data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3074,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. + // Give the stack the chance to transition to closed state. Note that since + // both the sender and receiver are now closed, we effectively skip the + // TIME-WAIT state. time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Wait for receive to be notified. select { case <-ch: @@ -3668,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) { } } -func TestSetCongestionControl(t *testing.T) { +func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { - cc tcp.CongestionControlOption - mustPass bool + cc tcpip.CongestionControlOption + err *tcpip.Error }{ - {"reno", true}, - {"cubic", true}, + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, } for _, tc := range testCases { @@ -3684,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) { s := c.Stack() - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err) + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) } - var cc tcp.CongestionControlOption + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tc.cc; got != want { + + got, want := cc, oldCC + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. 
+ if tc.err == nil { + want = tc.cc + } + if got != want { t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } } -func TestAvailableCongestionControl(t *testing.T) { +func TestStackAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcp.AvailableCongestionControlOption + var aCC tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) } } -func TestSetAvailableCongestionControl(t *testing.T) { +func TestStackSetAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcp.AvailableCongestionControlOption("xyz") + aCC := tcpip.AvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcp.AvailableCongestionControlOption + var cc tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + } +} + +func TestEndpointSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcpip.CongestionControlOption + err *tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, + } + + for _, connected := range []bool{false, true} { + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + var oldCC tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&oldCC); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + } + + if connected { + c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + } + + if err := c.EP.SetSockOpt(tc.cc); err != tc.err { + t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&cc); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + } + + got, want := cc, oldCC + // If SetSockOpt is expected to succeed then the + // returned value for congestion control should match + // the one specified in the SetSockOpt above, else it + // should be what it was before the call to SetSockOpt. 
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
 }
 }
 
 func enableCUBIC(t *testing.T, c *context.Context) {
 	t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
 t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v) = %v", opt, err)
 }
@@ -3868,7 +3479,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki RcvWnd: 30000, })
 // Receive the SYN-ACK reply.
 b := c.GetPacket()
 tcp := header.TCP(header.IPv4(b).Payload())
 iss = seqnum.Value(tcp.SequenceNumber())
@@ -3932,12 +3543,18 @@ func TestListenBacklogFull(t *testing.T) { time.Sleep(50 * time.Millisecond)
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point.
- irs, iss := executeHandshake(t, c, context.TestPort+uint16(listenBacklog), false /* synCookieInUse */)
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 2,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)

- time.Sleep(50 * time.Millisecond)
- // Try to accept the connection.
+ // Try to accept the connections in the backlog.
 we, ch := waiter.NewChannelEntry(nil)
 c.WQ.EventRegister(&we, waiter.EventIn)
 defer c.WQ.EventUnregister(&we)
@@ -3969,16 +3586,8 @@ func TestListenBacklogFull(t *testing.T) { } }
- // Now craft the ACK again and verify that the connection is now ready
- // to be accepted.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + uint16(listenBacklog),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
+ // Now a new handshake must succeed.
+ executeHandshake(t, c, context.TestPort+2, false /* synCookieInUse */)
 
 newEP, _, err := c.EP.Accept()
 if err == tcpip.ErrWouldBlock {
@@ -3994,6 +3603,7 @@ func TestListenBacklogFull(t *testing.T) { t.Fatalf("Timed out waiting for accept") } }
+ // Now verify that the TCP socket is usable and in a connected state.
 data := "Don't panic"
 newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
@@ -4004,13 +3614,7 @@ func TestListenBacklogFull(t *testing.T) { } }
-func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
+func TestListenSynRcvdQueueFull(t *testing.T) {
 c := context.New(t, defaultMTU)
 defer c.Cleanup()
@@ -4029,48 +3633,72 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Test acceptance. // Start listening. listenBacklog := 1
- portOffset := uint16(0)
 if err := c.EP.Listen(listenBacklog); err != nil {
 t.Fatalf("Listen failed: %v", err)
 }

- executeHandshake(t, c, context.TestPort+portOffset, false)
- portOffset++
- // Wait for this to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
+ // Send two SYNs; the first one should get a SYN-ACK, the
+ // second one should not get any response and is dropped as
+ // the synRcvd count will be equal to the backlog.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
 
- nonCookieIRS, nonCookieISS := executeHandshake(t, c, context.TestPort+portOffset, false)
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
 
- // Since the backlog is full at this point this connection will not
- // transition out of handshake and ignore the ACK.
- //
- // At this point there should be 1 completed connection in the backlog
- // and one incomplete one pending for a final ACK and hence not ready to be
- // delivered to the endpoint.
- //
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point
- // and there is already 1 pending endpoint.
+ // Now send one more SYN. The stack should not respond as the SYN-RCVD
+ // queue is full at this point.
 //
- // This one should use a SYN cookie as the synRcvdCount is equal to the
- // SynRcvdCountThreshold.
- time.Sleep(50 * time.Millisecond)
- portOffset++
- irs, iss := executeHandshake(t, c, context.TestPort+portOffset, true)
+ // NOTE: We did not complete the handshake for the previous one, so the
+ // accept backlog should be empty and there should be one connection in
+ // synRcvd state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
 
- time.Sleep(50 * time.Millisecond)
+ // Now complete the previous connection and verify that there is a connection
+ // to accept.
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
 
- // Verify that there is only one acceptable connection at this point.
+ // Try to accept the connections in the backlog.
 we, ch := waiter.NewChannelEntry(nil)
 c.WQ.EventRegister(&we, waiter.EventIn)
 defer c.WQ.EventUnregister(&we)
 
- _, _, err = c.EP.Accept()
+ newEP, _, err := c.EP.Accept()
 if err == tcpip.ErrWouldBlock {
 // Wait for connection to be established.
 select {
 case <-ch:
- _, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept()
 if err != nil {
 t.Fatalf("Accept failed: %v", err)
 }
@@ -4080,27 +3708,68 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
 }
 }
 
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
- if err != tcpip.ErrWouldBlock {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic" + newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + pkt := c.GetPacket() + tcp = header.TCP(header.IPv4(pkt).Payload()) + if string(tcp.Payload()) != data { + t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } +} - // Now send an ACK for the half completed connection +func TestListenBacklogFullSynCookieInUse(t *testing.T) { + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 1 + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + // Start listening. + listenBacklog := 1 + portOffset := uint16(0) + if err := c.EP.Listen(listenBacklog); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + executeHandshake(t, c, context.TestPort+portOffset, false) + portOffset++ + // Wait for this to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + + // Send a SYN request. + irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + portOffset - 1, + SrcPort: context.TestPort, DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: nonCookieIRS + 1, - AckNum: nonCookieISS + 1, + Flags: header.TCPFlagSyn, + SeqNum: irs, RcvWnd: 30000, }) + // The Syn should be dropped as the endpoint's backlog is full. + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + + // Verify that there is only one acceptable connection at this point. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) - // Verify that the connection is now delivered to the backlog. _, _, err = c.EP.Accept() if err == tcpip.ErrWouldBlock { // Wait for connection to be established. @@ -4116,41 +3785,15 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } - // Finally send an ACK for the connection that used a cookie and verify that - // it's also completed and delivered. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + portOffset, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs, - AckNum: iss, - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. + // Now verify that there are no more connections that can be accepted. + _, _, err = c.EP.Accept() + if err != tcpip.ErrWouldBlock { select { case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } } - - // Now verify that the TCP socket is usable and in a connected state. 
- data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } } func TestPassiveConnectionAttemptIncrement(t *testing.T) { @@ -4165,9 +3808,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } stats := c.Stack().Stats() want := stats.TCP.PassiveConnectionOpenings.Value() + 1 @@ -4218,18 +3867,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } srcPort := uint16(context.TestPort) - // Now attempt 3 handshakes, the first two will fill up the accept and the SYN-RCVD - // queue for the endpoint. + // Now attempt a handshakes it will fill up the accept backlog. executeHandshake(t, c, srcPort, false) // Give time for the final ACK to be processed as otherwise the next handshake could // get accepted before the previous one based on goroutine scheduling. time.Sleep(50 * time.Millisecond) - irs, iss := executeHandshake(t, c, srcPort+1, false) - - // Wait for a short while for the accepted connection to be delivered to - // the channel before trying to send the 3rd SYN. - time.Sleep(40 * time.Millisecond) want := stats.TCP.ListenOverflowSynDrop.Value() + 1 @@ -4267,26 +3910,44 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { t.Fatalf("Timed out waiting for accept") } } +} - // Now complete the next connection in SYN-RCVD state as it should - // have dropped the final ACK to the handshake due to accept queue - // being full. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) +func TestEndpointBindListenAcceptState(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } - // Now check that there is one more acceptable connections. - _, _, err = c.EP.Accept() + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + aep, _, err := ep.Accept() if err == tcpip.ErrWouldBlock { // Wait for connection to be established. 
 		select {
 		case <-ch:
-			_, _, err = c.EP.Accept()
+			aep, _, err = ep.Accept()
 			if err != nil {
 				t.Fatalf("Accept failed: %v", err)
 			}
@@ -4295,19 +3956,23 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
 			t.Fatalf("Timed out waiting for accept")
 		}
 	}
+	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+	// Listening endpoint remains in listen state.
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
-
-	// Try and accept a 3rd one this should fail.
-	_, _, err = c.EP.Accept()
-	if err == tcpip.ErrWouldBlock {
-		// Wait for connection to be established.
-		select {
-		case <-ch:
-			ep, _, err = c.EP.Accept()
-			if err == nil {
-				t.Fatalf("Accept succeeded when it should have failed got: %+v", ep)
-			}
-
-		case <-time.After(1 * time.Second):
-		}
+	ep.Close()
+	// Give worker goroutines time to receive the close notification.
+	time.Sleep(1 * time.Second)
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+	// Accepted endpoint remains open when the listen endpoint is closed.
+	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
 	}
+
 }
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6e12413c6..a4d89e24d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -520,32 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
 	c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
 }
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window (rcvWnd) and any options if
+// specified.
 //
 // It also sets the receive buffer for the endpoint to the specified
 // value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
-	// Create TCP endpoint.
-	var err *tcpip.Error
-	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
-	if err != nil {
-		c.t.Fatalf("NewEndpoint failed: %v", err)
-	}
-
-	if epRcvBuf != nil {
-		if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
-			c.t.Fatalf("SetSockOpt failed failed: %v", err)
-		}
-	}
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
 	// Start connection attempt.
 	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
 	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
 	defer c.WQ.EventUnregister(&waitEntry)
-	err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
-	if err != tcpip.ErrConnectStarted {
+	if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
 		c.t.Fatalf("Unexpected return value from Connect: %v", err)
 	}
@@ -557,13 +546,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
 			checker.TCPFlags(header.TCPFlagSyn),
 		),
 	)
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
-	tcp := header.TCP(header.IPv4(b).Payload())
-	c.IRS = seqnum.Value(tcp.SequenceNumber())
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
 	c.SendPacket(nil, &Headers{
-		SrcPort: tcp.DestinationPort(),
-		DstPort: tcp.SourcePort(),
+		SrcPort: tcpHdr.DestinationPort(),
+		DstPort: tcpHdr.SourcePort(),
 		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
 		SeqNum:  iss,
 		AckNum:  c.IRS.Add(1),
@@ -584,15 +576,38 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
 	// Wait for connection to be established.
 	select {
 	case <-notifyCh:
-		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
-		if err != nil {
+		if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
 			c.t.Fatalf("Unexpected error when connecting: %v", err)
 		}
 	case <-time.After(1 * time.Second):
 		c.t.Fatalf("Timed out waiting for connection")
 	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	c.Port = tcpHdr.SourcePort()
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
-
-	c.Port = tcp.SourcePort()
+	if epRcvBuf != nil {
+		if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+			c.t.Fatalf("SetSockOpt failed: %v", err)
+		}
+	}
+	c.Connect(iss, rcvWnd, options)
 }

 // RawEndpoint is just a small wrapper around a TCP endpoint's state to make
@@ -690,6 +705,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
 	if err != nil {
 		c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
 	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}

 	// Start connection attempt.
 	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -719,6 +737,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
 			}),
 		),
 	)
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
 	tcpSeg := header.TCP(header.IPv4(b).Payload())
 	synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
@@ -782,6 +804,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
 	case <-time.After(1 * time.Second):
 		c.t.Fatalf("Timed out waiting for connection")
 	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}

 	// Store the source port in use by the endpoint.
 	c.Port = tcpSeg.SourcePort()
@@ -821,10 +846,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
 	if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
 		c.t.Fatalf("Bind failed: %v", err)
 	}
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}

 	if err := ep.Listen(10); err != nil {
 		c.t.Fatalf("Listen failed: %v", err)
 	}
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}

 	rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
@@ -847,6 +878,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
 			c.t.Fatalf("Timed out waiting for accept")
 		}
 	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
 	return rep
 }
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 3d52a4f31..fa7278286 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
 // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
 func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
 }
+
+// State implements socket.Socket.State.
+func (e *endpoint) State() uint32 {
+	// TODO(b/112063468): Translate internal state to values returned by Linux.
+	return 0
+}
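The new State() hook on the UDP endpoint is a stub: it returns 0 until b/112063468 translates netstack's internal state into the values Linux reports. A minimal sketch of such a translation follows, not part of this change and assuming a hypothetical internal state field (endpointState and its values are invented for illustration). The constant values are the ones Linux itself uses: /proc/net/udp reuses the TCP state enum from include/net/tcp_states.h, reporting 7 (TCP_CLOSE) for unconnected UDP sockets and 1 (TCP_ESTABLISHED) for connected ones.

package udp

import "sync"

// Socket state values from Linux's include/net/tcp_states.h, which
// /proc/net/udp reuses for UDP sockets.
const (
	linuxTCPEstablished uint32 = 1 // connected socket
	linuxTCPClose       uint32 = 7 // unconnected socket
)

// endpointState is a hypothetical stand-in for the endpoint's internal
// lifecycle state; the real endpoint tracks this differently.
type endpointState int

const (
	stateInitial endpointState = iota
	stateBound
	stateConnected
	stateClosed
)

type endpoint struct {
	mu    sync.RWMutex
	state endpointState
}

// State returns a Linux-compatible value for the endpoint's current state.
func (e *endpoint) State() uint32 {
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.state == stateConnected {
		return linuxTCPEstablished
	}
	// Initial, bound and closed UDP sockets all surface as CLOSE in Linux.
	return linuxTCPClose
}

Mapping onto the TCP state enum rather than inventing UDP-specific values mirrors Linux, which exposes a single sk_state field for all socket types, so tools like ss and netstat can read gvisor's values unchanged.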